diff --git a/mlflow_client.py b/mlflow_client.py index de6b2dc..2870465 100644 --- a/mlflow_client.py +++ b/mlflow_client.py @@ -62,9 +62,11 @@ def model_parameters(self, name: str) -> Dict[str, Any] | None: pass return flatten_dict(parameters) - def model_versions(self, name: str) -> List[str] | None: + def model_versions(self, name: str) -> List[Dict[str, str]] | None: versions = self.client.search_model_versions(f"name='{name}'") - return [version.version for version in versions] + return [{"version": version.version, + "type": "" if version.tags.get("model_type") is None else version.tags["model_type"]} + for version in versions] def model_metrics(self, name: str) -> Dict[str, Any] | None: run_id = self.client.get_registered_model(name).latest_versions[0].run_id