Skip to content

Commit

Permalink
tidy: utilize state / ui registry over requiring correct scope
Browse files Browse the repository at this point in the history
  • Loading branch information
dmadisetti committed Sep 5, 2024
1 parent ecd57ca commit 9dc0ae3
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 123 deletions.
24 changes: 19 additions & 5 deletions marimo/_plugins/ui/_core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,33 @@ def _find_bindings_in_namespace(
bindings.add(name)
return bindings

def _register_bindings(self, object_id: UIElementId) -> None:
def _register_bindings(
self, object_id: UIElementId, glbls: Optional[dict[str, Any]] = None
) -> None:
from marimo._runtime.context.kernel_context import KernelRuntimeContext

ctx = get_context()
if isinstance(ctx, KernelRuntimeContext):
if isinstance(ctx, KernelRuntimeContext) or glbls is not None:
if glbls is None:
glbls = ctx.globals
self._bindings[object_id] = self._find_bindings_in_namespace(
object_id, ctx.globals
object_id, glbls
)

def lookup(self, name: str) -> Optional[UIElementId]:
def register_scope(
self, glbls: dict[str, Any], defs: Optional[set[str]] = None
) -> None:
if defs is None:
defs = set(glbls.keys())
for binding in defs:
lookup = glbls.get(binding, None)
if isinstance(lookup, UIElement):
self._register_bindings(lookup._id, glbls)

def lookup(self, name: str) -> Optional[UIElement[Any, Any]]:
for object_id, bindings in self._bindings.items():
if name in bindings:
return object_id
return self.get_object(object_id)
return None

def get_object(self, object_id: UIElementId) -> UIElement[Any, Any]:
Expand Down
2 changes: 2 additions & 0 deletions marimo/_runtime/context/kernel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ def create_kernel_context(
parent: KernelRuntimeContext | None = None,
) -> KernelRuntimeContext:
from marimo._plugins.ui._core.registry import UIElementRegistry
from marimo._runtime.state import StateRegistry
from marimo._runtime.virtual_file import VirtualFileRegistry

return KernelRuntimeContext(
_kernel=kernel,
_app=app,
ui_element_registry=UIElementRegistry(),
state_registry=StateRegistry(),
function_registry=FunctionRegistry(),
cell_lifecycle_registry=CellLifecycleRegistry(),
virtual_file_registry=VirtualFileRegistry(),
Expand Down
5 changes: 3 additions & 2 deletions marimo/_runtime/context/script_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from marimo._runtime.dataflow import DirectedGraph
from marimo._runtime.functions import FunctionRegistry
from marimo._runtime.params import CLIArgs, QueryParams
from marimo._runtime.state import State, StateRegistry

if TYPE_CHECKING:
from marimo._ast.app import InternalApp
from marimo._ast.cell import CellId_t
from marimo._messaging.types import Stream
from marimo._runtime.state import State


@dataclass
Expand All @@ -34,7 +34,7 @@ class ScriptRuntimeContext(RuntimeContext):

def __post_init__(self) -> None:
self._cli_args: CLIArgs | None = None
self._query_params = QueryParams({})
self._query_params = QueryParams({}, _registry=self.state_registry)

@property
def graph(self) -> DirectedGraph:
Expand Down Expand Up @@ -120,6 +120,7 @@ def initialize_script_context(app: InternalApp, stream: Stream) -> None:
runtime_context = ScriptRuntimeContext(
_app=app,
ui_element_registry=UIElementRegistry(),
state_registry=StateRegistry(),
function_registry=FunctionRegistry(),
cell_lifecycle_registry=CellLifecycleRegistry(),
virtual_file_registry=VirtualFileRegistry(),
Expand Down
3 changes: 2 additions & 1 deletion marimo/_runtime/context/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from marimo._output.hypertext import Html
from marimo._plugins.ui._core.registry import UIElementRegistry
from marimo._runtime.params import CLIArgs, QueryParams
from marimo._runtime.state import State
from marimo._runtime.state import State, StateRegistry
from marimo._runtime.virtual_file import VirtualFileRegistry


Expand Down Expand Up @@ -64,6 +64,7 @@ class ExecutionContext:
@dataclass
class RuntimeContext(abc.ABC):
ui_element_registry: UIElementRegistry
state_registry: StateRegistry
function_registry: FunctionRegistry
cell_lifecycle_registry: CellLifecycleRegistry
virtual_file_registry: VirtualFileRegistry
Expand Down
5 changes: 3 additions & 2 deletions marimo/_runtime/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SerializedCLIArgs,
SerializedQueryParams,
)
from marimo._runtime.state import State
from marimo._runtime.state import State, StateRegistry


@mddoc
Expand All @@ -31,8 +31,9 @@ def __init__(
self,
params: Dict[str, Union[str, List[str]]],
stream: Optional[Stream] = None,
_registry: Optional[StateRegistry] = None,
):
super().__init__(params)
super().__init__(params, _registry=_registry)
self._params = params
self._stream = stream

Expand Down
8 changes: 4 additions & 4 deletions marimo/_runtime/runner/hooks_post_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
from marimo._messaging.tracebacks import write_traceback
from marimo._output import formatting
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._runtime.context.types import get_global_context
from marimo._runtime.context.types import get_context, get_global_context
from marimo._runtime.control_flow import MarimoInterrupt, MarimoStopError
from marimo._runtime.runner import cell_runner
from marimo._runtime.state import StateRegistry
from marimo._tracer import kernel_tracer
from marimo._utils.flatten import contains_instance

Expand Down Expand Up @@ -176,8 +175,9 @@ def _store_state_reference(
run_result: cell_runner.RunResult,
) -> None:
del run_result
StateRegistry.register_scope(cell.defs, runner.glbls)
StateRegistry.retain_active_states(set(runner.glbls.keys()))
ctx = get_context()
ctx.state_registry.register_scope(runner.glbls, defs=cell.defs)
ctx.state_registry.retain_active_states(set(runner.glbls.keys()))


@kernel_tracer.start_as_current_span("broadcast_outputs")
Expand Down
91 changes: 50 additions & 41 deletions marimo/_runtime/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,77 +23,86 @@ class StateRegistry:
_states: dict[str, StateItem[Any]] = {}
_inv_states: dict[int, set[str]] = {}

@staticmethod
def register(name: str, state: State[T]) -> None:
if id(state) in StateRegistry._inv_states:
ref = next(iter(StateRegistry._inv_states[id(state)]))
if StateRegistry._states[ref].id != id(state):
for ref in StateRegistry._inv_states[id(state)]:
del StateRegistry._states[ref]
StateRegistry._inv_states[id(state)].clear()
def register(self, state: State[T], name: Optional[str] = None) -> None:
if name is None:
name = str(uuid4())
if id(state) in self._inv_states:
ref = next(iter(self._inv_states[id(state)]))
if self._states[ref].id != id(state):
for ref in self._inv_states[id(state)]:
del self._states[ref]
self._inv_states[id(state)].clear()
state_item = StateItem(id(state), weakref.ref(state))
StateRegistry._states[name] = state_item
id_to_ref = StateRegistry._inv_states.get(id(state), set())
self._states[name] = state_item
id_to_ref = self._inv_states.get(id(state), set())
id_to_ref.add(name)
StateRegistry._inv_states[id(state)] = id_to_ref
finalizer = weakref.finalize(
state, StateRegistry._delete, name, state_item
)
self._inv_states[id(state)] = id_to_ref
finalizer = weakref.finalize(state, self._delete, name, state_item)
# No need to clean up the registry at program teardown
finalizer.atexit = False

@staticmethod
def register_scope(defs: set[str], glbls: dict[str, Any]) -> None:
def register_scope(
self, glbls: dict[str, Any], defs: Optional[set[str]] = None
) -> None:
if defs is None:
defs = set(glbls.keys())
for variable in defs:
lookup = glbls.get(variable, None)
if isinstance(lookup, State):
StateRegistry.register(variable, lookup)
self.register(lookup, variable)

@staticmethod
def _delete(name: str, state_item: StateItem[T]) -> None:
StateRegistry._states.pop(name, None)
StateRegistry._inv_states.pop(state_item.id, None)
def _delete(self, name: str, state_item: StateItem[T]) -> None:
self._states.pop(name, None)
self._inv_states.pop(state_item.id, None)

@staticmethod
def retain_active_states(active_variables: set[str]) -> None:
def retain_active_states(self, active_variables: set[str]) -> None:
"""Retains only the active states in the registry."""
# Remove all non-active states by name
active_state_ids = set()
for state_name in list(StateRegistry._states.keys()):
for state_name in list(self._states.keys()):
if state_name not in active_variables:
StateRegistry._inv_states.pop(
id(StateRegistry._states[state_name]), None
)
del StateRegistry._states[state_name]
self._inv_states.pop(id(self._states[state_name]), None)
del self._states[state_name]
else:
active_state_ids.add(id(StateRegistry._states[state_name]))
active_state_ids.add(id(self._states[state_name]))

# Remove all non-active states by id
for state_id in list(StateRegistry._inv_states.keys()):
for state_id in list(self._inv_states.keys()):
if state_id not in active_state_ids:
del StateRegistry._inv_states[state_id]
del self._inv_states[state_id]

@staticmethod
def lookup(name: str) -> Optional[State[T]]:
if name in StateRegistry._states:
return StateRegistry._states[name].ref()
def lookup(self, name: str) -> Optional[State[T]]:
if name in self._states:
return self._states[name].ref()
return None

@staticmethod
def get_references(state: State[T]) -> set[str]:
if id(state) in StateRegistry._inv_states:
return StateRegistry._inv_states[id(state)]
def bound_names(self, state: State[T]) -> set[str]:
if id(state) in self._inv_states:
return self._inv_states[id(state)]
return set()


class State(Generic[T]):
"""Mutable reactive state"""

def __init__(self, value: T, allow_self_loops: bool = False) -> None:
def __init__(
self,
value: T,
allow_self_loops: bool = False,
_registry: Optional[StateRegistry] = None,
) -> None:
self._value = value
self.allow_self_loops = allow_self_loops
self._set_value = SetFunctor(self)
StateRegistry.register(str(uuid4()), self)

try:
if _registry is None:
_registry = get_context().state_registry
_registry.register(self)
except ContextNotInitializedError:
# Registration may be picked up later, but there is nothing to do
# at this point.
pass

def __call__(self) -> T:
return self._value
Expand Down
33 changes: 19 additions & 14 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

from marimo._ast.visitor import ScopedVisitor
from marimo._dependencies.dependencies import DependencyManager
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._runtime.context import get_context
from marimo._runtime.primitives import (
FN_CACHE_TYPE,
is_data_primitive,
is_primitive,
is_pure_function,
)
from marimo._runtime.state import SetFunctor, State, StateRegistry
from marimo._runtime.state import SetFunctor, State
from marimo._save.cache import Cache, CacheType
from marimo._utils.variables import (
get_cell_from_local,
Expand All @@ -30,6 +30,7 @@
from types import CodeType

from marimo._ast.cell import CellId_t, CellImpl
from marimo._runtime.context.types import RuntimeContext
from marimo._runtime.dataflow import DirectedGraph
from marimo._save.loaders import Loader

Expand Down Expand Up @@ -203,6 +204,7 @@ def normalize_and_extract_ref_state(
refs: set[str],
defs: dict[str, Any],
cell_id: CellId_t,
ctx: RuntimeContext,
) -> set[str]:
stateful_refs = set()

Expand All @@ -229,15 +231,15 @@ def normalize_and_extract_ref_state(

# State relevant to the context, should be dependent on it's value- not
# the object.
if isinstance(defs[ref], State):
value = defs[ref]()
for state_name in StateRegistry.get_references(defs[ref]):
defs[state_name] = value
value: Optional[State[Any]]
if value := ctx.state_registry.lookup(ref):
for state_name in ctx.state_registry.bound_names(value):
defs[state_name] = value()

# Likewise, UI objects should be dependent on their value.
if isinstance(defs[ref], UIElement):
ui = defs[ref]
defs[ref] = ui.value
if (ui := ctx.ui_element_registry.lookup(ref)) is not None:
for ui_name in ctx.ui_element_registry.bound_names(ui._id):
defs[ui_name] = ui.value
# If the UI is directly consumed, then hold on to the reference
# for proper cache update.
stateful_refs.add(ref)
Expand Down Expand Up @@ -298,9 +300,12 @@ def cache_attempt_from_hash(
# Get stateful registers
# This is typically done in post execution hook, but it will not be called
# in script mode.
StateRegistry.register_scope(set(defs.keys()), defs)
ctx = get_context()
ctx.ui_element_registry.register_scope(defs)
ctx.state_registry.register_scope(defs)

stateful_refs = normalize_and_extract_ref_state(
visitor.refs, refs, defs, cell_id
visitor.refs, refs, defs, cell_id, ctx
)

# usedforsecurity=False used to satisfy some static analysis tools.
Expand All @@ -321,15 +326,15 @@ def cache_attempt_from_hash(
refs |= set(
filter(
lambda ref: (
StateRegistry.lookup(ref)
or isinstance(defs[ref], UIElement)
ctx.state_registry.lookup(ref)
or ctx.ui_element_registry.lookup(ref)
),
transitive_state_refs,
)
)
# Need to run extract again for the expanded ref set.
stateful_refs |= normalize_and_extract_ref_state(
visitor.refs, refs, defs, cell_id
visitor.refs, refs, defs, cell_id, ctx
)

# Attempt content hash
Expand Down
Loading

0 comments on commit 9dc0ae3

Please sign in to comment.