Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ability to add fields to StateDataClass #47

Closed
wants to merge 8 commits into from
143 changes: 142 additions & 1 deletion src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@
import logging
import warnings
from collections import UserDict
from dataclasses import dataclass, field, fields, is_dataclass, replace
from dataclasses import (
asdict,
dataclass,
field,
fields,
is_dataclass,
make_dataclass,
replace,
)
from enum import Enum
from functools import singledispatch, wraps
from typing import (
Callable,
Dict,
Generic,
List,
Literal,
Mapping,
Optional,
Protocol,
Expand All @@ -34,6 +43,8 @@
T = TypeVar("T")
C = TypeVar("C", covariant=True)

MISSING = ...


class DeltaAddable(Protocol[C]):
"""A class which a Delta or other Mapping can be added to, returning the same class"""
Expand Down Expand Up @@ -697,6 +708,136 @@ def update(self, **kwargs):
"""
return self + Delta(**kwargs)

def add_field(
self,
name: str,
value=MISSING,
type_=MISSING,
default=MISSING,
default_factory=MISSING,
delta: Literal["replace", "extend", "append"] = "replace",
converter: Optional[Callable] = None,
):
"""
Return a new copy of the dataclass with an additional field

Start with a StateDataClass; here is an empty one:
>>> s = StateDataClass()
>>> s
StateDataClass()

You can add a field with a new value:
>>> t = s.add_field("b", value=1)
>>> t
StateDataClass(b=1)

The original State is unchanged:
>>> s
StateDataClass()

You can add a field with a default value which is None:
>>> s.add_field("a", default=None)
StateDataClass(a=None)

... or a field with a default_factory (like a list):
>>> s.add_field("l", default_factory=list)
StateDataClass(l=[])

... but not both:
>>> s.add_field("l", default_factory=list, default=None)
Traceback (most recent call last):
...
ValueError: cannot specify both default and default_factory


A field with a default value can also have a different value in the instantiation:
>>> s.add_field("c", default=0, value=1)
StateDataClass(c=1)

... and here with a default_factory:
>>> s.add_field("d", default_factory=list, value=[1, 2, 3])
StateDataClass(d=[1, 2, 3])

You can specify the "type" of the field:
>>> s.add_field("e", type_=List[int], value=[1, 2, 3])
StateDataClass(e=[1, 2, 3])

... but this is only for documentation and will throw no error if the value is wrong
>>> s.add_field("f", type_=int, value="a string, not an int")
StateDataClass(f='a string, not an int')

By default, fields are replaced if modified by Deltas:
>>> s.add_field("h", value=1) + Delta(h=2)
StateDataClass(h=2)

(This is the same as:
>>> s.add_field("h", value=1, delta="replace") + Delta(h=2)
StateDataClass(h=2)

You can specify a different delta type:
>>> u = s.add_field("g", type_=List[int], default_factory=list, delta="extend")
>>> u = u + Delta(g=[3]) + Delta(g=[4])
>>> u
StateDataClass(g=[3, 4])

>>> v = u.add_field("h", type_=int, default=None, delta="replace")
>>> v
StateDataClass(g=[3, 4], h=None)

>>> v + Delta(h=1) + Delta(h=2)
StateDataClass(g=[3, 4], h=2)

>>> w = v.add_field("i", type_=List[int], default_factory=list, delta="append")
>>> w + Delta(i=3) + Delta(i=9) + Delta(i=27)
StateDataClass(g=[3, 4], h=None, i=[3, 9, 27])

You can specify a converter:
>>> x = s.add_field("df", type_=pd.DataFrame, default=None, converter=pd.DataFrame)
>>> x + Delta(df = {"a": [1, 2, 3], "b": list("abc")})
StateDataClass(df= a b
0 1 a
1 2 b
2 3 c)

"""
_field_kwargs = {
k: v
for k, v in dict(default=default, default_factory=default_factory).items()
if v is not MISSING
}

_field_kwargs["metadata"] = {
k: v
for k, v in dict(delta=delta, converter=converter).items()
if v is not None
}

_field = field(**_field_kwargs)
_dataclass_params = {
key: getattr(getattr(self, "__dataclass_params__"), key)
for key in ("frozen", "init", "repr", "eq", "order", "unsafe_hash")
}

new_class = make_dataclass(
cls_name=self.__class__.__name__ + str(hash((name, type_, _field))),
fields=[(name, type_, _field)],
bases=(self.__class__,),
**_dataclass_params,
)

if value is not MISSING:
new_value = {name: value}
else:
new_value = {}
if not (default != MISSING or default_factory != MISSING):
raise ValueError(
"`value` or `default` or `default_factory` must be specified"
)

new = new_class(**new_value, **asdict(self))

return new


def _get_value(f, other: Union[Delta, Mapping]):
"""
Expand Down
34 changes: 0 additions & 34 deletions tests/README.md

This file was deleted.

128 changes: 128 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import pickle
from collections import namedtuple
from typing import Any, List

import dill
import pandas as pd
from hypothesis import given
from hypothesis import strategies as st

from autora.state import StateDataClass


@given(st.sampled_from([StateDataClass()]))
def test_statedict_serialize_deserialize(s):
s_dumped = pickle.dumps(s)
s_loaded = pickle.loads(s_dumped)
assert s_loaded == s


REPLACE = "replace"
EXTEND = "extend"
APPEND = "append"


@st.composite
def type_delta_strategy(draw):
outer_type = {list: [REPLACE, EXTEND, APPEND]}

allowable_combinations = {
bool: [REPLACE],
int: [REPLACE],
str: [REPLACE],
dict: [REPLACE, EXTEND],
list: [REPLACE, EXTEND, APPEND],
pd.DataFrame: [REPLACE, EXTEND],
Any: [REPLACE],
}

st.sampled_from([(None, REPLACE), (list, REPLACE), (list, EXTEND), (list, APPEND)])


_FIELD_DEF = namedtuple("_FIELD_DEF", ["type_", "delta"])


@st.composite
def field_strategy(draw):
outer_type_strategy = st.sampled_from([None, List])
core_type_strategy = st.sampled_from(
[
bool,
int,
str,
# dict, # TODO: add support for this
# pd.DataFrame, # TODO: add support for this
# Any # TODO: add support for this
]
)
value_strategies = {
bool: st.booleans,
int: st.integers,
str: st.text,
List: st.lists,
}
default_factory_strategies = {List: list}

delta_strategy = st.sampled_from([REPLACE, EXTEND, APPEND])
outer_type, inner_type, delta = draw(
st.tuples(outer_type_strategy, core_type_strategy, delta_strategy)
)

if outer_type is None:
merged_type = inner_type
field_def = {"type_": merged_type, "delta": delta}
value_field = draw(st.sampled_from(["value", "default"]))
value_strategy = value_strategies[inner_type]
field_def[value_field] = draw(value_strategy())

elif outer_type is not None:
merged_type = outer_type[inner_type]
field_def = {"type_": merged_type, "delta": delta}

value_field = draw(st.sampled_from(["value", "default", "default_factory"]))
if value_field == "default_factory":
field_def[value_field] = default_factory_strategies[outer_type]
else:
value_strategy = value_strategies[outer_type](
value_strategies[inner_type]()
)
field_def[value_field] = draw(value_strategy)

return field_def


@st.composite
def state_object_strategy(draw):
variable_name_strategy = st.from_regex("\A[_A-Za-z][_A-Za-z0-9]*\Z")

field_names = draw(st.lists(variable_name_strategy, unique=True))
field_names_defs = {n: draw(field_strategy()) for n in field_names}

d = StateDataClass()
for name, f in field_names_defs.items():
d = d.add_field(name=name, **f)

return d


@given(state_object_strategy())
def test_statedict_serialize_deserialize_data(s):
s

s_dumped = pickle.dumps(s)
s_loaded = pickle.loads(s_dumped)
assert s_loaded == s


@given(st.data())
def test_draw_sequentially(data):
x = data.draw(state_object_strategy())


@given(state_object_strategy())
def test_statedict_serialize_deserialize_data_dill(s):
s

s_dumped = dill.dumps(s)
s_loaded = dill.loads(s_dumped)
assert s_loaded == s
Loading