44import abc
55import inspect
66import re
7+ import textwrap
78from collections import namedtuple
89from dataclasses import dataclass , field
910from typing import TYPE_CHECKING , Any , Literal , get_args
1415from marimo ._runtime .state import SetFunctor
1516from marimo ._save .stubs import (
1617 CUSTOM_STUBS ,
18+ ClassStub ,
1719 CustomStub ,
1820 FunctionStub ,
1921 ModuleStub ,
2022 ReferenceStub ,
2123 UIElementStub ,
2224 maybe_register_stub ,
2325)
26+ from marimo ._save .stubs .lazy_stub import UnhashableStub
2427
2528# Many assertions are for typing and should always pass. This message is a
2629# catch all to motive users to report if something does fail.
@@ -114,6 +117,24 @@ def dispose(self, context: RuntimeContext, deletion: bool) -> bool: # noqa: ARG
114117)
115118
116119
120+ def _source_refs (code : str ) -> set [Name ]:
121+ """Free references of a captured class/function source block.
122+
123+ Reuses the cell `ScopedVisitor` so the notion of "what this def needs"
124+ matches marimo's dataflow rather than a bespoke parser.
125+ """
126+ from marimo ._ast .parse import ast_parse
127+ from marimo ._ast .visitor import ScopedVisitor
128+
129+ try :
130+ tree = ast_parse (textwrap .dedent (code ))
131+ except SyntaxError :
132+ return set ()
133+ visitor = ScopedVisitor ("cache_restore" )
134+ visitor .visit (tree )
135+ return set (visitor .refs )
136+
137+
117138# BaseException because "raise _ as e" is utilized.
118139class CacheException (BaseException ):
119140 pass
@@ -136,7 +157,7 @@ class Cache:
136157 def restore (self , scope : dict [str , Any ]) -> None :
137158 """Restores values from cache, into scope."""
138159 memo : dict [int , Any ] = {} # Track processed objects to handle cycles
139- for var , lookup in self .contextual_defs ():
160+ for var , lookup in self ._restore_order ( self . contextual_defs () ):
140161 value = self .defs .get (var , None )
141162 scope [lookup ] = self ._restore_from_stub_if_needed (
142163 value , scope , memo
@@ -167,6 +188,44 @@ def restore(self, scope: dict[str, Any]) -> None:
167188 f"({ type (ref )} :{ ref } )."
168189 )
169190
191+ def _restore_deps (self , value : Any ) -> set [Name ]:
192+ """Cross-def names *value* needs before it can be restored.
193+
194+ - A re-exec'd class/function needs the defs it references at
195+ definition time (bases, decorators, class-body calls).
196+ - A pickled instance of a cell-defined class needs that class
197+ materialized first (tagged via `requires`).
198+ """
199+ if isinstance (value , (ClassStub , FunctionStub )):
200+ return _source_refs (value .code )
201+ requires = getattr (value , "requires" , "" )
202+ return {requires } if requires else set ()
203+
204+ def _restore_order (
205+ self , contextual_defs : dict [tuple [Name , Name ], Any ]
206+ ) -> list [tuple [Name , Name ]]:
207+ """Order `(var, lookup)` pairs so each def's cross-def dependencies
208+ restore first (depth-first; tolerant of cycles via `seen`)."""
209+ lookups = {var : lookup for var , lookup in contextual_defs }
210+ deps = {
211+ var : self ._restore_deps (self .defs .get (var )) & lookups .keys ()
212+ for var in lookups
213+ }
214+ order : list [tuple [Name , Name ]] = []
215+ seen : set [Name ] = set ()
216+
217+ def visit (var : Name ) -> None :
218+ if var in seen :
219+ return
220+ seen .add (var )
221+ for dep in deps [var ]:
222+ visit (dep )
223+ order .append ((var , lookups [var ]))
224+
225+ for var in lookups :
226+ visit (var )
227+ return order
228+
170229 def _restore_from_stub_if_needed (
171230 self ,
172231 value : Any ,
@@ -223,6 +282,16 @@ def _restore_from_stub_if_needed(
223282 result = value
224283 elif isinstance (value , ReferenceStub ):
225284 result = value .load (scope )
285+ elif isinstance (value , ClassStub ):
286+ # Re-exec the captured source into the cell namespace so the
287+ # name rebinds to a live class (not the stub). Must run before
288+ # any pickle blob referencing the class deserializes, so the
289+ # class is resolvable as `__main__.<name>` in `scope`.
290+ result = value .load (scope )
291+ elif isinstance (value , UnhashableStub ):
292+ # Marker for a def whose value couldn't be serialized. Place it
293+ # in scope as-is.
294+ result = value
226295 elif isinstance (value , CustomStub ):
227296 # CustomStub is a placeholder for a custom type, which cannot be
228297 # restored directly.
@@ -315,8 +384,28 @@ def _convert_to_stub_if_needed(
315384
316385 if inspect .ismodule (value ):
317386 result = ModuleStub (value )
318- elif inspect .isfunction (value ):
387+ elif (
388+ inspect .isfunction (value )
389+ and value .__name__ != "<lambda>"
390+ and getattr (value , "__module__" , "__main__" ) == "__main__"
391+ ):
392+ # NB. Lambdas can't round-trip via FunctionStub: inspect.getsource
393+ # returns the line *containing* the lambda (e.g. "return model,
394+ # lambda inp: model(inp)"), which fails to compile as a module.
319395 result = FunctionStub (value )
396+ elif (
397+ inspect .isclass (value )
398+ and getattr (value , "__module__" , None ) == "__main__"
399+ ):
400+ # Attempt to capture classes by source so the loader can rebuild
401+ # them in the cell namespace before pickle blobs that reference them
402+ # deserialize. Pass the executing cell's source filename as a hint
403+ # so attribute-only classes (no method code object to read a
404+ # filename from) still resolve against `linecache`.
405+ try :
406+ result = ClassStub (value , filename = self ._cell_filename ())
407+ except (TypeError , OSError ):
408+ result = value
320409 elif isinstance (value , UIElement ):
321410 result = UIElementStub (value )
322411 elif isinstance (value , tuple ):
@@ -397,6 +486,20 @@ def _convert_to_stub_if_needed(
397486
398487 return result
399488
489+ @staticmethod
490+ def _cell_filename () -> str | None :
491+ """Source filename of the executing cell, for sourcing classes that
492+ lack a method code object. `None` when there is no runtime context
493+ (e.g. script-mode / direct API use)."""
494+ try :
495+ from marimo ._ast .compiler import get_filename
496+
497+ context = get_context ().execution_context
498+ assert context is not None
499+ return get_filename (context .cell_id )
500+ except (ContextNotInitializedError , AssertionError ):
501+ return None
502+
400503 def contextual_defs (self ) -> dict [tuple [Name , Name ], Any ]:
401504 """Uses context to resolve private variable names."""
402505 try :
0 commit comments