diff --git a/main.py b/main.py index 22304a1..aadd387 100644 --- a/main.py +++ b/main.py @@ -5,9 +5,11 @@ from fastapi.middleware.cors import CORSMiddleware from mlflow_client import Client from pydantic import BaseModel -from schemas import Models, Parameters, Metrics, Dataset, Images +from schemas import Models, Parameters, Metrics, Dataset, Images, Versions from dotenv import load_dotenv +load_dotenv() + app = FastAPI() client = Client() @@ -79,12 +81,12 @@ async def model_images(name: str): @app.get("/model/versions", tags=["Endpoints that gets all the versions of a specified register model"], - response_model=Metrics) + response_model=Versions) async def model_metrics(name: str): - metrics = client.model_metrics(name) - if metrics is None: - return JSONResponse("Error getting the metrics!", status_code=500) - return JSONResponse(metrics, status_code=200) + versions = client.model_versions(name) + if versions is None: + return JSONResponse("Error getting the model versions!", status_code=500) + return JSONResponse(versions, status_code=200) @app.post("/model/predict") diff --git a/mlflow_client.py b/mlflow_client.py index 0153172..de6b2dc 100644 --- a/mlflow_client.py +++ b/mlflow_client.py @@ -4,7 +4,7 @@ import mlflow import base64 import pandas as pd -from typing import Any, Dict +from typing import Any, Dict, List from datetime import datetime from json import JSONDecodeError from io import StringIO, BytesIO @@ -62,9 +62,9 @@ def model_parameters(self, name: str) -> Dict[str, Any] | None: pass return flatten_dict(parameters) - def model_versions(self, name: str) -> Dict[str, Any] | None: - versions = self.client.get_registered_model(name).latest_versions - return versions + def model_versions(self, name: str) -> List[str] | None: + versions = self.client.search_model_versions(f"name='{name}'") + return [version.version for version in versions] def model_metrics(self, name: str) -> Dict[str, Any] | None: run_id = self.client.get_registered_model(name).latest_versions[0].run_id diff --git a/schemas.py b/schemas.py index ecfaee9..b237a54 100644 --- a/schemas.py +++ b/schemas.py @@ -1,6 +1,5 @@ -from fastapi import File from pydantic import BaseModel -from typing import Dict, Any +from typing import Dict, Any, List class Models(BaseModel): @@ -17,7 +16,7 @@ class Metrics(BaseModel): class Versions(BaseModel): - versions: Dict[str, Any] + versions: List[str] class Dataset(BaseModel):