Skip to content

Commit

Permalink
Merge pull request #649 from FabienArcellier/620-state-schema-does-no…
Browse files Browse the repository at this point in the history
…t-accept-generic-dict-types

fix: state schema should accept generic dict types as typing.Dict and TypedDict - WF-115
  • Loading branch information
ramedina86 authored Nov 29, 2024
2 parents 129dd54 + 030eb20 commit 985e6cf
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 4 deletions.
30 changes: 29 additions & 1 deletion src/writer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import secrets
import time
import traceback
import typing
import urllib.request
from contextvars import ContextVar
from multiprocessing.process import BaseProcess
Expand Down Expand Up @@ -677,6 +678,7 @@ def bind_annotations_to_state_proxy(cls, klass):
proxy = DictPropertyProxy("_state_proxy", key)
setattr(klass, key, proxy)


class State(metaclass=StateMeta):

def __init__(self, raw_state: Optional[Dict[str, Any]] = None):
Expand Down Expand Up @@ -776,7 +778,7 @@ def _set_state_item(self, key: str, value: Any):
"""
annotations = get_annotations(self)
expected_type = annotations.get(key, None)
expect_dict = expected_type is not None and inspect.isclass(expected_type) and issubclass(expected_type, dict)
expect_dict = _type_match_dict(expected_type)
if isinstance(value, dict) and not expect_dict:
"""
When the value is a dictionary and the attribute does not explicitly
Expand Down Expand Up @@ -2078,6 +2080,32 @@ def _deserialize_bigint_format(payload: Optional[Union[dict, list]]):

return payload


def _type_match_dict(expected_type: Type):
"""
Checks if the expected type expect a dictionary type
>>> _type_match_dict(dict) # True
>>> _type_match_dict(int) # False
>>> _type_match_dict(Dict[str, Any]) # True
>>> class SpecifcDict(TypedDict):
>>> a: str
>>> b: str
>>>
>>> _type_match_dict(SpecifcDict) # True
"""
if expected_type is not None and \
inspect.isclass(expected_type) and \
issubclass(expected_type, dict):
return True

if typing.get_origin(expected_type) == dict:
return True

return False


def unescape_bigint_matching_string(string: str) -> str:
"""
Unescapes a string
Expand Down
80 changes: 77 additions & 3 deletions tests/backend/test_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import math
import typing
import unittest
import urllib
from typing import Dict
from typing import Any, Dict

import altair
import numpy as np
Expand Down Expand Up @@ -268,7 +269,7 @@ def to_dict(self):

class TestState:

def test_set_dictionary_in_a_state_should_transform_it_in_state_proxy_and_trigger_mutation(self):
def test_state_shema_set_dictionary_in_a_state_should_transform_it_in_state_proxy_and_trigger_mutation(self):
"""
Tests that writing a dictionary in a State without schema is transformed into a StateProxy and
triggers mutations to update the interface
Expand All @@ -287,7 +288,7 @@ def test_set_dictionary_in_a_state_should_transform_it_in_state_proxy_and_trigge
r"+new\.state\.with\.dots.test": "test"
}

def test_set_dictionary_in_a_state_with_schema_should_transform_it_in_state_proxy_and_trigger_mutation(self):
def test_state_shema_set_dictionary_in_a_state_with_schema_should_keep_as_dict(self):
class SimpleSchema(State):
app: dict

Expand Down Expand Up @@ -601,6 +602,79 @@ def cumulative_sum(state):
# Assert
assert initial_state['total'] == 4

def test_state_shema_should_accept_Dict_generic_typing_for_dict(self):
"""
A schema must accept a generic typed dictionary for a dictionary
and manipulate it as a dictionary, not as a StateProxy.
"""
with writer_fixtures.new_app_context():
# Assign
class MyState(wf.WriterState):
counter: int
record: Dict[str, Any]

# Acts
initial_state = wf.init_state({
"counter": 0,
"record": {}
}, schema=MyState)

# Assert
assert isinstance(initial_state['counter'], int)
assert isinstance(initial_state['record'], dict)
assert initial_state['record'] == {}

def test_state_shema_should_accept_DictTyped_typing_for_dict(self):
"""
A schema must accept a dictionary typed for a dictionary
and manipulate it as a dictionary, not as a StateProxy.
"""
class SpecificDictTyped(typing.TypedDict):
a: str
b: str

with writer_fixtures.new_app_context():
# Assign
class MyState(wf.WriterState):
counter: int
record: SpecificDictTyped

# Acts
initial_state = wf.init_state({
"counter": 0,
"record": {}
}, schema=MyState)

# Assert
assert isinstance(initial_state['counter'], int)
assert isinstance(initial_state['record'], dict)
assert initial_state['record'] == {}

def test_state_shema_should_support_convert_dict_into_state_by_default(self):
"""
A schema must accept a dictionary typed for a dictionary
and manipulate it as a dictionary, not as a StateProxy.
"""
class Substate(State):
a: str
b: str

with writer_fixtures.new_app_context():
# Assign
class MyState(wf.WriterState):
counter: int
record: Substate

# Acts
initial_state = wf.init_state({
"counter": 0,
"record": {}
}, schema=MyState)

# Assert
assert isinstance(initial_state['record'], Substate)


class TestWriterState:

# Initialised manually
Expand Down

0 comments on commit 985e6cf

Please sign in to comment.