diff --git a/diabetes_regression/util/model_helper.py b/diabetes_regression/util/model_helper.py index f90237e5..0fd20ef0 100644 --- a/diabetes_regression/util/model_helper.py +++ b/diabetes_regression/util/model_helper.py @@ -8,8 +8,8 @@ def get_current_workspace() -> Workspace: """ - Retrieves and returns the latest model from the workspace - by its name and tag. Will not work when ran locally. + Retrieves and returns the current workspace. + Will not work when ran locally. Parameters: None @@ -30,8 +30,8 @@ def get_model( aml_workspace: Workspace = None ) -> AMLModel: """ - Retrieves and returns the latest model from the workspace - by its name and (optional) tag. + Retrieves and returns a model from the workspace by its name + and (optional) tag. Parameters: aml_workspace (Workspace): aml.core Workspace that the model lives. @@ -40,25 +40,40 @@ def get_model( (optional) tag (str): the tag value & name the model was registered under. Return: - A single aml model from the workspace that matches the name and tag. + A single aml model from the workspace that matches the name and tag, or + None. """ if aml_workspace is None: print("No workspace defined - using current experiment workspace.") aml_workspace = get_current_workspace() - if tag_name is not None and tag_value is not None: + tags = None + if tag_name is not None or tag_value is not None: + # Both a name and value must be specified to use tags. + if tag_name is None or tag_value is None: + raise ValueError( + "model_tag_name and model_tag_value should both be supplied" + + "or excluded" # NOQA: E501 + ) + tags = [[tag_name, tag_value]] + + model = None + if model_version is not None: + # TODO(tcare): Finding a specific version currently expects exceptions + # to propagate in the case we can't find the model. This call may + # result in a WebserviceException that may or may not be due to the + # model not existing. model = AMLModel( aml_workspace, name=model_name, version=model_version, - tags=[[tag_name, tag_value]]) - elif (tag_name is None and tag_value is not None) or ( - tag_value is None and tag_name is not None - ): - raise ValueError( - "model_tag_name and model_tag_value should both be supplied" - + "or excluded" # NOQA: E501 - ) + tags=tags) else: - model = AMLModel(aml_workspace, name=model_name, version=model_version) # NOQA: E501 + models = AMLModel.list( + aml_workspace, name=model_name, tags=tags, latest=True) + if len(models) == 1: + model = models[0] + elif len(models) > 1: + raise Exception("Expected only one model") + return model