Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ModelBuilder.load versatility improvements #210

Merged
merged 4 commits into from
Jul 11, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def sample_model(self, **kwargs):
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))

self.set_idata_attrs(idata)
idata = self.set_idata_attrs(idata)
return idata

def set_idata_attrs(self, idata=None):
Expand Down Expand Up @@ -338,6 +338,10 @@ def set_idata_attrs(self, idata=None):
idata.attrs["version"] = self.version
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
# Only classes with non-dataset parameters will implement save_input_params
if hasattr(self, "_save_input_params"):
self._save_input_params(idata)
return idata

def save(self, fname: str) -> None:
"""
Expand Down Expand Up @@ -375,6 +379,12 @@ def save(self, fname: str) -> None:
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

def _convert_dims_to_tuple(model_config: Dict) -> Dict:
for key in model_config:
if "dims" in model_config[key] and isinstance(model_config[key]["dims"], list):
model_config[key]["dims"] = tuple(model_config[key]["dims"])
return model_config

@classmethod
def load(cls, fname: str):
"""
Expand Down Expand Up @@ -403,8 +413,10 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
# needs to be converted, because json.loads was changing tuple to list
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
model = cls(
model_config=json.loads(idata.attrs["model_config"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
Expand Down Expand Up @@ -480,6 +492,7 @@ def fit(
combined_data = pd.concat([X_df, y], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore

return self.idata # type: ignore

def predict(
Expand Down