diff --git a/src/autora/state.py b/src/autora/state.py index 5095b6bc..75e11838 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -270,6 +270,53 @@ def update(self, **kwargs): """ return self + Delta(**kwargs) + def copy(self): + """ + Return a deepcopy of the State + Examples: + >>> @dataclass(frozen=True) + ... class DfState(State): + ... q: pd.DataFrame = field(default_factory=pd.DataFrame, + ... metadata={"delta": "replace", + ... "converter": pd.DataFrame}) + >>> data = pd.DataFrame({'x': [1, 2, 3]}) + >>> s_1 = DfState(q=data) + >>> s_replace = replace(s_1) + >>> s_copy = s_1.copy() + + The build in replace method doesn't create a deepcopy: + >>> s_1.q is s_replace.q + True + >>> s_1.q['y'] = [1,2,3] + >>> s_replace.q + x y + 0 1 1 + 1 2 2 + 2 3 3 + + But this copy method does: + >>> s_1.q is s_copy.q + False + >>> s_copy.q + x + 0 1 + 1 2 + 2 3 + + + """ + # Create a dictionary to hold the field copies + field_copies = {} + + # Iterate over all fields of the class + for _field in fields(self): + value = getattr(self, _field.name) + # Use deepcopy to ensure that mutable fields are also copied + field_copies[_field.name] = copy.deepcopy(value) + + # Use replace with **field_copies to create a new instance of the same class + return replace(self, **field_copies) + def _get_value(f, other: Union[Delta, Mapping]): """