Skip to content

Commit 6c7da04

Browse files
authored
feat(save): stub serialization toolkit (class/lazy/module stubs) (#9896)
Expands the current stubbing mechanism to cover a wider range of Python objects (classes, lambdas, modules, and module imports). Introduces: - UnhashableStub - ClassStub Additionally: - Handles module imports explicitly to prevent them from being serialized to disk - Replaces pickled PyTorch tensor objects with a `.pt` loader
1 parent 54f83a2 commit 6c7da04

17 files changed

Lines changed: 1361 additions & 40 deletions

marimo/_runtime/exceptions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
if TYPE_CHECKING:
77
from marimo._runtime.dataflow import DirectedGraph
8+
from marimo._types.ids import CellId_t
89

910

1011
class MarimoRuntimeException(BaseException):
@@ -26,6 +27,39 @@ def __init__(self, ref: str, name_error: NameError | None = None) -> None:
2627
self.name_error = name_error
2728

2829

30+
class MarimoCancelCellError(BaseException):
    """Control-flow signal that a lifecycle raises to soft-cancel a cell.

    Derives from BaseException rather than Exception, so a user's
    `except Exception` handler cannot intercept the signal.

    Positional arguments are forwarded to BaseException unchanged;
    `cells_to_rerun` is keyword-only and stored on the instance.
    """

    cells_to_rerun: set[CellId_t]

    def __init__(
        self,
        *args: object,
        cells_to_rerun: set[CellId_t] | None = None,
    ) -> None:
        super().__init__(*args)
        # A missing or empty argument is normalized to a fresh empty set;
        # a non-empty set is stored as-is (same object, not a copy).
        if cells_to_rerun:
            self.cells_to_rerun = cells_to_rerun
        else:
            self.cells_to_rerun = set()
46+
47+
48+
class MarimoUnhashableCacheError(MarimoCancelCellError):
    """Cancel signal for cell-level caching that hit a value which cannot
    be hashed/serialized for cache restoration.

    Attributes set on the instance:
        variables: the `variables` list passed by the caller.
        error_details: the message, also forwarded to the base class as
            the sole positional exception argument.
    """

    def __init__(
        self,
        cells_to_rerun: set[CellId_t],
        variables: list[str],
        error_details: str,
    ) -> None:
        # The base initializer records `cells_to_rerun` and the message.
        super().__init__(
            error_details,
            cells_to_rerun=cells_to_rerun,
        )
        self.error_details = error_details
        self.variables = variables
61+
62+
2963
def unwrap_user_exception(
3064
exc: MarimoRuntimeException,
3165
graph: DirectedGraph | None = None,

marimo/_save/cache.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import abc
55
import inspect
66
import re
7+
import textwrap
78
from collections import namedtuple
89
from dataclasses import dataclass, field
910
from typing import TYPE_CHECKING, Any, Literal, get_args
@@ -14,13 +15,15 @@
1415
from marimo._runtime.state import SetFunctor
1516
from marimo._save.stubs import (
1617
CUSTOM_STUBS,
18+
ClassStub,
1719
CustomStub,
1820
FunctionStub,
1921
ModuleStub,
2022
ReferenceStub,
2123
UIElementStub,
2224
maybe_register_stub,
2325
)
26+
from marimo._save.stubs.lazy_stub import UnhashableStub
2427

2528
# Many assertions are for typing and should always pass. This message is a
2629
# catch-all to motivate users to report if something does fail.
@@ -114,6 +117,24 @@ def dispose(self, context: RuntimeContext, deletion: bool) -> bool: # noqa: ARG
114117
)
115118

116119

120+
def _source_refs(code: str) -> set[Name]:
    """Return the free references of a captured class/function source.

    The source is dedented, parsed, and walked with the cell
    `ScopedVisitor`, so "what this def needs" is computed the same way as
    for marimo's dataflow instead of by a bespoke parser.

    Source that fails to parse (SyntaxError) yields an empty set.
    """
    from marimo._ast.parse import ast_parse
    from marimo._ast.visitor import ScopedVisitor

    dedented = textwrap.dedent(code)
    try:
        module_ast = ast_parse(dedented)
    except SyntaxError:
        return set()
    else:
        scoped = ScopedVisitor("cache_restore")
        scoped.visit(module_ast)
        # Copy into a plain set so callers get their own container.
        return set(scoped.refs)
136+
137+
117138
# BaseException because "raise _ as e" is utilized.
118139
class CacheException(BaseException):
119140
pass
@@ -136,7 +157,7 @@ class Cache:
136157
def restore(self, scope: dict[str, Any]) -> None:
137158
"""Restores values from cache, into scope."""
138159
memo: dict[int, Any] = {} # Track processed objects to handle cycles
139-
for var, lookup in self.contextual_defs():
160+
for var, lookup in self._restore_order(self.contextual_defs()):
140161
value = self.defs.get(var, None)
141162
scope[lookup] = self._restore_from_stub_if_needed(
142163
value, scope, memo
@@ -167,6 +188,44 @@ def restore(self, scope: dict[str, Any]) -> None:
167188
f"({type(ref)}:{ref})."
168189
)
169190

191+
def _restore_deps(self, value: Any) -> set[Name]:
    """Names of other defs that must be restored before *value*.

    - Class/function stubs are re-exec'd from source, so they need every
      name their source references at definition time (bases,
      decorators, class-body calls).
    - Any other value may carry a `requires` attribute naming the
      cell-defined class that has to be materialized first; a missing
      or empty `requires` means no dependencies.
    """
    if not isinstance(value, (ClassStub, FunctionStub)):
        tag = getattr(value, "requires", "")
        if tag:
            return {tag}
        return set()
    return _source_refs(value.code)
203+
204+
def _restore_order(
    self, contextual_defs: dict[tuple[Name, Name], Any]
) -> list[tuple[Name, Name]]:
    """Sort `(var, lookup)` pairs so dependencies are restored first.

    Performs a depth-first walk over each def's cross-def dependencies
    (as reported by `_restore_deps`, limited to names present in
    `contextual_defs`). Each def is emitted only after its dependencies;
    the `visited` set makes the walk terminate on cycles.
    """
    # Iterating the mapping yields its `(var, lookup)` key tuples.
    lookup_by_var = {}
    for var, lookup in contextual_defs:
        lookup_by_var[var] = lookup

    # Dependencies on names outside this cache's defs are dropped.
    needs = {}
    for var in lookup_by_var:
        needs[var] = (
            self._restore_deps(self.defs.get(var)) & lookup_by_var.keys()
        )

    ordered: list[tuple[Name, Name]] = []
    visited: set[Name] = set()

    def emit(name: Name) -> None:
        if name not in visited:
            # Mark before recursing so a cycle cannot loop forever.
            visited.add(name)
            for prerequisite in needs[name]:
                emit(prerequisite)
            ordered.append((name, lookup_by_var[name]))

    for name in lookup_by_var:
        emit(name)
    return ordered
228+
170229
def _restore_from_stub_if_needed(
171230
self,
172231
value: Any,
@@ -223,6 +282,16 @@ def _restore_from_stub_if_needed(
223282
result = value
224283
elif isinstance(value, ReferenceStub):
225284
result = value.load(scope)
285+
elif isinstance(value, ClassStub):
286+
# Re-exec the captured source into the cell namespace so the
287+
# name rebinds to a live class (not the stub). Must run before
288+
# any pickle blob referencing the class deserializes, so the
289+
# class is resolvable as `__main__.<name>` in `scope`.
290+
result = value.load(scope)
291+
elif isinstance(value, UnhashableStub):
292+
# Marker for a def whose value couldn't be serialized. Place it
293+
# in scope as-is.
294+
result = value
226295
elif isinstance(value, CustomStub):
227296
# CustomStub is a placeholder for a custom type, which cannot be
228297
# restored directly.
@@ -315,8 +384,28 @@ def _convert_to_stub_if_needed(
315384

316385
if inspect.ismodule(value):
317386
result = ModuleStub(value)
318-
elif inspect.isfunction(value):
387+
elif (
388+
inspect.isfunction(value)
389+
and value.__name__ != "<lambda>"
390+
and getattr(value, "__module__", "__main__") == "__main__"
391+
):
392+
# NB. Lambdas can't round-trip via FunctionStub: inspect.getsource
393+
# returns the line *containing* the lambda (e.g. "return model,
394+
# lambda inp: model(inp)"), which fails to compile as a module.
319395
result = FunctionStub(value)
396+
elif (
397+
inspect.isclass(value)
398+
and getattr(value, "__module__", None) == "__main__"
399+
):
400+
# Attempt to capture classes by source so the loader can rebuild
401+
# them in the cell namespace before pickle blobs that reference them
402+
# deserialize. Pass the executing cell's source filename as a hint
403+
# so attribute-only classes (no method code object to read a
404+
# filename from) still resolve against `linecache`.
405+
try:
406+
result = ClassStub(value, filename=self._cell_filename())
407+
except (TypeError, OSError):
408+
result = value
320409
elif isinstance(value, UIElement):
321410
result = UIElementStub(value)
322411
elif isinstance(value, tuple):
@@ -397,6 +486,20 @@ def _convert_to_stub_if_needed(
397486

398487
return result
399488

489+
@staticmethod
def _cell_filename() -> str | None:
    """Source filename of the executing cell, or `None` if unavailable.

    Used as a hint for sourcing classes that lack a method code object.

    Returns:
        The result of `get_filename` for the executing cell's id, or
        `None` when there is no runtime context or no execution context
        (e.g. script-mode / direct API use).
    """
    try:
        from marimo._ast.compiler import get_filename

        context = get_context().execution_context
        if context is None:
            # Explicit check instead of `assert`: asserts are stripped
            # under `python -O`, which would turn this case into an
            # uncaught AttributeError on `context.cell_id`.
            return None
        return get_filename(context.cell_id)
    except (ContextNotInitializedError, AssertionError):
        # AssertionError is still caught so that any assertion raised by
        # the calls above keeps mapping to "no filename".
        return None
502+
400503
def contextual_defs(self) -> dict[tuple[Name, Name], Any]:
401504
"""Uses context to resolve private variable names."""
402505
try:

marimo/_save/hash.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,11 +705,13 @@ def serialize_and_dequeue_content_refs(
705705
version = ""
706706
module = None
707707
if self.pin_modules:
708-
module = sys.modules[imports[ref].module]
709-
version = getattr(module, "__version__", "")
708+
# `.get`: a cached module def restored as a missing-
709+
# module placeholder is in scope but not sys.modules.
710+
module = sys.modules.get(imports[ref].module)
711+
version = getattr(module, "__version__", "") or ""
710712
if not version:
711-
module = sys.modules[imports[ref].namespace]
712-
version = getattr(module, "__version__", "")
713+
module = sys.modules.get(imports[ref].namespace)
714+
version = getattr(module, "__version__", "") or ""
713715

714716
content_serialization[ref] = type_sign(
715717
bytes(f"module:{ref}:{version}", "utf-8"), "module"
@@ -1078,6 +1080,7 @@ def cache_attempt_from_hash(
10781080
hasher.defs,
10791081
hasher.key,
10801082
hasher.stateful_refs,
1083+
glbls=scope,
10811084
)
10821085

10831086

@@ -1166,4 +1169,5 @@ def content_cache_attempt_from_base(
11661169
hasher.defs,
11671170
hasher.key,
11681171
stateful_refs,
1172+
glbls=scope,
11691173
)

0 commit comments

Comments
 (0)