From ab5ffa9d816cae62d2a63500919e783b5b0efcb7 Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Thu, 2 May 2024 04:39:50 -0300 Subject: [PATCH] Handle REST status codes (#74) * core: handle 400 and 404 Signed-off-by: Isabella do Amaral * handle conversion errors as 400 Signed-off-by: Isabella do Amaral --------- Signed-off-by: Isabella do Amaral --- .../api_model_registry_service_service.go | 212 ++++++++++++----- pkg/api/error.go | 8 + pkg/core/core.go | 216 +++++++++--------- pkg/core/core_test.go | 48 ++-- 4 files changed, 294 insertions(+), 190 deletions(-) create mode 100644 pkg/api/error.go diff --git a/internal/server/openapi/api_model_registry_service_service.go b/internal/server/openapi/api_model_registry_service_service.go index e5d5b98f..7b3d7700 100644 --- a/internal/server/openapi/api_model_registry_service_service.go +++ b/internal/server/openapi/api_model_registry_service_service.go @@ -11,6 +11,7 @@ package openapi import ( "context" + "errors" "github.com/kubeflow/model-registry/internal/apiutils" "github.com/kubeflow/model-registry/internal/converter" @@ -46,15 +47,17 @@ func (s *ModelRegistryServiceAPIService) CreateEnvironmentInferenceService(ctx c func (s *ModelRegistryServiceAPIService) CreateInferenceService(ctx context.Context, inferenceServiceCreate model.InferenceServiceCreate) (ImplResponse, error) { entity, err := s.converter.ConvertInferenceServiceCreate(&inferenceServiceCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertInferenceService(entity) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil } @@ -62,32 +65,39 @@ func (s *ModelRegistryServiceAPIService) CreateInferenceService(ctx context.Cont func (s *ModelRegistryServiceAPIService) CreateInferenceServiceServe(ctx context.Context, inferenceserviceId string, serveModelCreate model.ServeModelCreate) (ImplResponse, error) { entity, err := s.converter.ConvertServeModelCreate(&serveModelCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertServeModel(entity, &inferenceserviceId) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } + return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // CreateModelArtifact - Create a ModelArtifact func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context, modelArtifactCreate model.ModelArtifactCreate) (ImplResponse, error) { entity, err := s.converter.ConvertModelArtifactCreate(&modelArtifactCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertModelArtifact(entity, nil) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil } @@ -95,15 +105,17 @@ func (s *ModelRegistryServiceAPIService) CreateModelArtifact(ctx context.Context func (s *ModelRegistryServiceAPIService) CreateModelVersion(ctx context.Context, modelVersionCreate model.ModelVersionCreate) (ImplResponse, error) { modelVersion, err := s.converter.ConvertModelVersionCreate(&modelVersionCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertModelVersion(modelVersion, &modelVersionCreate.RegisteredModelId) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil } @@ -111,28 +123,32 @@ func (s *ModelRegistryServiceAPIService) CreateModelVersion(ctx context.Context, func (s *ModelRegistryServiceAPIService) CreateModelVersionArtifact(ctx context.Context, modelversionId string, artifact model.Artifact) (ImplResponse, error) { result, err := s.coreApi.UpsertArtifact(&artifact, &modelversionId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil // return Response(http.StatusNotImplemented, nil), errors.New("unsupported artifactType") // TODO return Response(200, Artifact{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // CreateRegisteredModel - Create a RegisteredModel func (s *ModelRegistryServiceAPIService) CreateRegisteredModel(ctx context.Context, registeredModelCreate model.RegisteredModelCreate) (ImplResponse, error) { registeredModel, err := s.converter.ConvertRegisteredModelCreate(®isteredModelCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertRegisteredModel(registeredModel) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil } @@ -140,27 +156,33 @@ func (s *ModelRegistryServiceAPIService) CreateRegisteredModel(ctx context.Conte func (s *ModelRegistryServiceAPIService) CreateRegisteredModelVersion(ctx context.Context, registeredmodelId string, modelVersion model.ModelVersion) (ImplResponse, error) { result, err := s.coreApi.UpsertModelVersion(&modelVersion, apiutils.StrPtr(registeredmodelId)) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // CreateServingEnvironment - Create a ServingEnvironment func (s *ModelRegistryServiceAPIService) CreateServingEnvironment(ctx context.Context, servingEnvironmentCreate model.ServingEnvironmentCreate) (ImplResponse, error) { entity, err := s.converter.ConvertServingEnvironmentCreate(&servingEnvironmentCreate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } result, err := s.coreApi.UpsertServingEnvironment(entity) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(201, result), nil - // TODO: return Response(400, Error{}), nil // TODO: return Response(401, Error{}), nil } @@ -168,59 +190,77 @@ func (s *ModelRegistryServiceAPIService) CreateServingEnvironment(ctx context.Co func (s *ModelRegistryServiceAPIService) FindInferenceService(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { result, err := s.coreApi.GetInferenceServiceByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return esponse(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // FindModelArtifact - Get a ModelArtifact that matches search parameters. func (s *ModelRegistryServiceAPIService) FindModelArtifact(ctx context.Context, name string, externalId string, parentResourceId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactByParams(apiutils.StrPtr(name), apiutils.StrPtr(parentResourceId), apiutils.StrPtr(externalId)) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return esponse(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // FindModelVersion - Get a ModelVersion that matches search parameters. func (s *ModelRegistryServiceAPIService) FindModelVersion(ctx context.Context, name string, externalId string, registeredModelId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionByParams(apiutils.StrPtr(name), apiutils.StrPtr(registeredModelId), apiutils.StrPtr(externalId)) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return esponse(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // FindRegisteredModel - Get a RegisteredModel that matches search parameters. func (s *ModelRegistryServiceAPIService) FindRegisteredModel(ctx context.Context, name string, externalID string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return esponse(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // FindServingEnvironment - Find ServingEnvironment func (s *ModelRegistryServiceAPIService) FindServingEnvironment(ctx context.Context, name string, externalID string) (ImplResponse, error) { result, err := s.coreApi.GetServingEnvironmentByParams(apiutils.StrPtr(name), apiutils.StrPtr(externalID)) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetEnvironmentInferenceServices - List All ServingEnvironment's InferenceServices @@ -231,33 +271,39 @@ func (s *ModelRegistryServiceAPIService) GetEnvironmentInferenceServices(ctx con } result, err := s.coreApi.GetInferenceServices(listOpts, apiutils.StrPtr(servingenvironmentId), nil) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetInferenceService - Get a InferenceService func (s *ModelRegistryServiceAPIService) GetInferenceService(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetInferenceServiceById(inferenceserviceId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetInferenceServiceModel - Get InferenceService's RegisteredModel func (s *ModelRegistryServiceAPIService) GetInferenceServiceModel(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelByInferenceService(inferenceserviceId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetInferenceServiceServes - List All InferenceService's ServeModel actions @@ -268,22 +314,26 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServiceServes(ctx context.C } result, err := s.coreApi.GetServeModels(listOpts, apiutils.StrPtr(inferenceserviceId)) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetInferenceServiceVersion - Get InferenceService's ModelVersion func (s *ModelRegistryServiceAPIService) GetInferenceServiceVersion(ctx context.Context, inferenceserviceId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionByInferenceService(inferenceserviceId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetInferenceServices - List All InferenceServices @@ -294,22 +344,26 @@ func (s *ModelRegistryServiceAPIService) GetInferenceServices(ctx context.Contex } result, err := s.coreApi.GetInferenceServices(listOpts, nil, nil) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetModelArtifact - Get a ModelArtifact func (s *ModelRegistryServiceAPIService) GetModelArtifact(ctx context.Context, modelartifactId string) (ImplResponse, error) { result, err := s.coreApi.GetModelArtifactById(modelartifactId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetModelArtifacts - List All ModelArtifacts @@ -320,23 +374,29 @@ func (s *ModelRegistryServiceAPIService) GetModelArtifacts(ctx context.Context, } result, err := s.coreApi.GetModelArtifacts(listOpts, nil) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetModelVersion - Get a ModelVersion func (s *ModelRegistryServiceAPIService) GetModelVersion(ctx context.Context, modelversionId string) (ImplResponse, error) { result, err := s.coreApi.GetModelVersionById(modelversionId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetModelVersionArtifacts - List All ModelVersion's artifacts @@ -349,11 +409,13 @@ func (s *ModelRegistryServiceAPIService) GetModelVersionArtifacts(ctx context.Co } result, err := s.coreApi.GetArtifacts(listOpts, apiutils.StrPtr(modelversionId)) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetModelVersions - List All ModelVersions @@ -364,23 +426,29 @@ func (s *ModelRegistryServiceAPIService) GetModelVersions(ctx context.Context, p } result, err := s.coreApi.GetModelVersions(listOpts, nil) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetRegisteredModel - Get a RegisteredModel func (s *ModelRegistryServiceAPIService) GetRegisteredModel(ctx context.Context, registeredmodelId string) (ImplResponse, error) { result, err := s.coreApi.GetRegisteredModelById(registeredmodelId) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetRegisteredModelVersions - List All RegisteredModel's ModelVersions @@ -393,11 +461,13 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModelVersions(ctx context. } result, err := s.coreApi.GetModelVersions(listOpts, apiutils.StrPtr(registeredmodelId)) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetRegisteredModels - List All RegisteredModels @@ -408,23 +478,29 @@ func (s *ModelRegistryServiceAPIService) GetRegisteredModels(ctx context.Context } result, err := s.coreApi.GetRegisteredModels(listOpts) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // GetServingEnvironment - Get a ServingEnvironment func (s *ModelRegistryServiceAPIService) GetServingEnvironment(ctx context.Context, servingenvironmentId string) (ImplResponse, error) { result, err := s.coreApi.GetServingEnvironmentById(servingenvironmentId) if err != nil { + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil // TODO: return Response(401, Error{}), nil - // TODO: return Response(404, Error{}), nil } // GetServingEnvironments - List All ServingEnvironments @@ -445,83 +521,103 @@ func (s *ModelRegistryServiceAPIService) GetServingEnvironments(ctx context.Cont func (s *ModelRegistryServiceAPIService) UpdateInferenceService(ctx context.Context, inferenceserviceId string, inferenceServiceUpdate model.InferenceServiceUpdate) (ImplResponse, error) { entity, err := s.converter.ConvertInferenceServiceUpdate(&inferenceServiceUpdate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } entity.Id = &inferenceserviceId result, err := s.coreApi.UpsertInferenceService(entity) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // UpdateModelArtifact - Update a ModelArtifact func (s *ModelRegistryServiceAPIService) UpdateModelArtifact(ctx context.Context, modelartifactId string, modelArtifactUpdate model.ModelArtifactUpdate) (ImplResponse, error) { modelArtifact, err := s.converter.ConvertModelArtifactUpdate(&modelArtifactUpdate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } modelArtifact.Id = &modelartifactId result, err := s.coreApi.UpsertModelArtifact(modelArtifact, nil) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // UpdateModelVersion - Update a ModelVersion func (s *ModelRegistryServiceAPIService) UpdateModelVersion(ctx context.Context, modelversionId string, modelVersionUpdate model.ModelVersionUpdate) (ImplResponse, error) { modelVersion, err := s.converter.ConvertModelVersionUpdate(&modelVersionUpdate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } modelVersion.Id = &modelversionId result, err := s.coreApi.UpsertModelVersion(modelVersion, nil) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // UpdateRegisteredModel - Update a RegisteredModel func (s *ModelRegistryServiceAPIService) UpdateRegisteredModel(ctx context.Context, registeredmodelId string, registeredModelUpdate model.RegisteredModelUpdate) (ImplResponse, error) { registeredModel, err := s.converter.ConvertRegisteredModelUpdate(®isteredModelUpdate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } registeredModel.Id = ®isteredmodelId result, err := s.coreApi.UpsertRegisteredModel(registeredModel) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } // UpdateServingEnvironment - Update a ServingEnvironment func (s *ModelRegistryServiceAPIService) UpdateServingEnvironment(ctx context.Context, servingenvironmentId string, servingEnvironmentUpdate model.ServingEnvironmentUpdate) (ImplResponse, error) { entity, err := s.converter.ConvertServingEnvironmentUpdate(&servingEnvironmentUpdate) if err != nil { - return Response(500, model.Error{Message: err.Error()}), nil + return Response(400, model.Error{Message: err.Error()}), nil } entity.Id = &servingenvironmentId result, err := s.coreApi.UpsertServingEnvironment(entity) if err != nil { + if errors.Is(err, api.ErrBadRequest) { + return Response(400, model.Error{Message: err.Error()}), nil + } + if errors.Is(err, api.ErrNotFound) { + return Response(404, model.Error{Message: err.Error()}), nil + } return Response(500, model.Error{Message: err.Error()}), nil } return Response(200, result), nil - // TODO return Response(400, Error{}), nil // TODO return Response(401, Error{}), nil - // TODO return Response(404, Error{}), nil } diff --git a/pkg/api/error.go b/pkg/api/error.go new file mode 100644 index 00000000..9e83e3ae --- /dev/null +++ b/pkg/api/error.go @@ -0,0 +1,8 @@ +package api + +import "errors" + +var ( + ErrBadRequest = errors.New("bad request") + ErrNotFound = errors.New("not found") +) diff --git a/pkg/core/core.go b/pkg/core/core.go index afb0eaf8..7d41873b 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -56,48 +56,48 @@ func BuildTypesMap(cc grpc.ClientConnInterface, nameConfig mlmdtypes.MLMDTypeNam } registeredModelResp, err := client.GetContextType(context.Background(), ®isteredModelContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.RegisteredModelTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %w", nameConfig.RegisteredModelTypeName, err) } modelVersionContextTypeReq := proto.GetContextTypeRequest{ TypeName: &nameConfig.ModelVersionTypeName, } modelVersionResp, err := client.GetContextType(context.Background(), &modelVersionContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.ModelVersionTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %w", nameConfig.ModelVersionTypeName, err) } docArtifactResp, err := client.GetArtifactType(context.Background(), &proto.GetArtifactTypeRequest{ TypeName: &nameConfig.DocArtifactTypeName, }) if err != nil { - return nil, fmt.Errorf("error getting artifact type %s: %v", nameConfig.DocArtifactTypeName, err) + return nil, fmt.Errorf("error getting artifact type %s: %w", nameConfig.DocArtifactTypeName, err) } modelArtifactArtifactTypeReq := proto.GetArtifactTypeRequest{ TypeName: &nameConfig.ModelArtifactTypeName, } modelArtifactResp, err := client.GetArtifactType(context.Background(), &modelArtifactArtifactTypeReq) if err != nil { - return nil, fmt.Errorf("error getting artifact type %s: %v", nameConfig.ModelArtifactTypeName, err) + return nil, fmt.Errorf("error getting artifact type %s: %w", nameConfig.ModelArtifactTypeName, err) } servingEnvironmentContextTypeReq := proto.GetContextTypeRequest{ TypeName: &nameConfig.ServingEnvironmentTypeName, } servingEnvironmentResp, err := client.GetContextType(context.Background(), &servingEnvironmentContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.ServingEnvironmentTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %w", nameConfig.ServingEnvironmentTypeName, err) } inferenceServiceContextTypeReq := proto.GetContextTypeRequest{ TypeName: &nameConfig.InferenceServiceTypeName, } inferenceServiceResp, err := client.GetContextType(context.Background(), &inferenceServiceContextTypeReq) if err != nil { - return nil, fmt.Errorf("error getting context type %s: %v", nameConfig.InferenceServiceTypeName, err) + return nil, fmt.Errorf("error getting context type %s: %w", nameConfig.InferenceServiceTypeName, err) } serveModelExecutionReq := proto.GetExecutionTypeRequest{ TypeName: &nameConfig.ServeModelTypeName, } serveModelResp, err := client.GetExecutionType(context.Background(), &serveModelExecutionReq) if err != nil { - return nil, fmt.Errorf("error getting execution type %s: %v", nameConfig.ServeModelTypeName, err) + return nil, fmt.Errorf("error getting execution type %s: %w", nameConfig.ServeModelTypeName, err) } typesMap := map[string]int64{ @@ -131,7 +131,7 @@ func (serv *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi withNotEditable, err := serv.openapiConv.OverrideNotEditableForRegisteredModel(converter.NewOpenapiUpdateWrapper(existing, registeredModel)) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } registeredModel = &withNotEditable } @@ -165,7 +165,7 @@ func (serv *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.Re idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ @@ -176,16 +176,16 @@ func (serv *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.Re } if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for id %s", id) + return nil, fmt.Errorf("multiple registered models found for id %s: %w", id, api.ErrNotFound) } if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered model found for id %s", id) + return nil, fmt.Errorf("no registered model found for id %s: %w", id, api.ErrNotFound) } regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return regModel, nil @@ -206,7 +206,7 @@ func (serv *ModelRegistryService) getRegisteredModelByVersionId(id string) (*ope idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ @@ -217,16 +217,16 @@ func (serv *ModelRegistryService) getRegisteredModelByVersionId(id string) (*ope } if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for model version %s", id) + return nil, fmt.Errorf("multiple registered models found for model version %s: %w", id, api.ErrNotFound) } if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered model found for model version %s", id) + return nil, fmt.Errorf("no registered model found for model version %s: %w", id, api.ErrNotFound) } regModel, err := serv.mapper.MapToRegisteredModel(getParentResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return regModel, nil @@ -243,7 +243,7 @@ func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, exter } else if externalId != nil { filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) } else { - return nil, fmt.Errorf("invalid parameters call, supply either name or externalId") + return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) } glog.Info("filterQuery ", filterQuery) @@ -258,16 +258,16 @@ func (serv *ModelRegistryService) GetRegisteredModelByParams(name *string, exter } if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no registered models found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("no registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } regModel, err := serv.mapper.MapToRegisteredModel(getByParamsResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return regModel, nil } @@ -317,7 +317,7 @@ func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.Model // create glog.Info("Creating new model version") if registeredModelId == nil { - return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model") + return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model: %w", api.ErrBadRequest) } registeredModel, err = serv.GetRegisteredModelById(*registeredModelId) if err != nil { @@ -333,7 +333,7 @@ func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.Model withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelVersion(converter.NewOpenapiUpdateWrapper(existing, modelVersion)) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } modelVersion = &withNotEditable @@ -345,7 +345,7 @@ func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.Model modelCtx, err := serv.mapper.MapFromModelVersion(modelVersion, *registeredModel.Id, registeredModel.Name) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } modelCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ @@ -361,7 +361,7 @@ func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.Model if modelVersion.Id == nil { registeredModelId, err := converter.StringToInt64(registeredModel.Id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ @@ -389,7 +389,7 @@ func (serv *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.Model func (serv *ModelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) { idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ @@ -400,16 +400,16 @@ func (serv *ModelRegistryService) GetModelVersionById(id string) (*openapi.Model } if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for id %s", id) + return nil, fmt.Errorf("multiple model versions found for id %s: %w", id, api.ErrNotFound) } if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no model version found for id %s", id) + return nil, fmt.Errorf("no model version found for id %s: %w", id, api.ErrNotFound) } modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return modelVer, nil @@ -432,7 +432,7 @@ func (serv *ModelRegistryService) GetModelVersionByInferenceService(inferenceSer return nil, err } if len(versions.Items) == 0 { - return nil, fmt.Errorf("no model versions found for id %s", is.RegisteredModelId) + return nil, fmt.Errorf("no model versions found for id %s: %w", is.RegisteredModelId, api.ErrNotFound) } return &versions.Items[0], nil } @@ -443,7 +443,7 @@ func (serv *ModelRegistryService) getModelVersionByArtifactId(id string) (*opena idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getParentResp, err := serv.mlmdClient.GetContextsByArtifact(context.Background(), &proto.GetContextsByArtifactRequest{ @@ -454,16 +454,16 @@ func (serv *ModelRegistryService) getModelVersionByArtifactId(id string) (*opena } if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for artifact %s", id) + return nil, fmt.Errorf("multiple model versions found for artifact %s: %w", id, api.ErrNotFound) } if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no model version found for artifact %s", id) + return nil, fmt.Errorf("no model version found for artifact %s: %w", id, api.ErrNotFound) } modelVersion, err := serv.mapper.MapToModelVersion(getParentResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return modelVersion, nil @@ -478,7 +478,7 @@ func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, r } else if externalId != nil { filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) } else { - return nil, fmt.Errorf("invalid parameters call, supply either (versionName and registeredModelId), or externalId") + return nil, fmt.Errorf("invalid parameters call, supply either (versionName and registeredModelId), or externalId: %w", api.ErrBadRequest) } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ @@ -492,16 +492,16 @@ func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, r } if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple model versions found for versionName=%v, registeredModelId=%v, externalId=%v", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("multiple model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no model versions found for versionName=%v, registeredModelId=%v, externalId=%v", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("no model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return modelVer, nil } @@ -510,7 +510,7 @@ func (serv *ModelRegistryService) GetModelVersionByParams(versionName *string, r func (serv *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, registeredModelId *string) (*openapi.ModelVersionList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } if registeredModelId != nil { @@ -530,7 +530,7 @@ func (serv *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, for _, c := range contextsResp.Contexts { mapped, err := serv.mapper.MapToModelVersion(c) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } @@ -560,11 +560,11 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod creating = true glog.Info("Creating model artifact") if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot create artifact without model version") + return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) } _, err := serv.GetModelVersionById(*modelVersionId) if err != nil { - return nil, fmt.Errorf("no model version found for id %s", *modelVersionId) + return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) } } else { glog.Info("Updating model artifact") @@ -575,7 +575,7 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod withNotEditable, err := serv.openapiConv.OverrideNotEditableForModelArtifact(converter.NewOpenapiUpdateWrapper(existing, ma)) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } ma = &withNotEditable @@ -589,11 +589,11 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod creating = true glog.Info("Creating doc artifact") if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot create artifact without model version") + return nil, fmt.Errorf("missing model version id, cannot create artifact without model version: %w", api.ErrBadRequest) } _, err := serv.GetModelVersionById(*modelVersionId) if err != nil { - return nil, fmt.Errorf("no model version found for id %s", *modelVersionId) + return nil, fmt.Errorf("no model version found for id %s: %w", *modelVersionId, api.ErrNotFound) } } else { glog.Info("Updating doc artifact") @@ -602,12 +602,12 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod return nil, err } if existing.DocArtifact == nil { - return nil, fmt.Errorf("mismatched types, artifact with id %s is not a doc artifact", *da.Id) + return nil, fmt.Errorf("mismatched types, artifact with id %s is not a doc artifact: %w", *da.Id, api.ErrBadRequest) } withNotEditable, err := serv.openapiConv.OverrideNotEditableForDocArtifact(converter.NewOpenapiUpdateWrapper(existing.DocArtifact, da)) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } da = &withNotEditable @@ -617,11 +617,11 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod } } } else { - return nil, fmt.Errorf("invalid artifact type, must be either ModelArtifact or DocArtifact") + return nil, fmt.Errorf("invalid artifact type, must be either ModelArtifact or DocArtifact: %w", api.ErrBadRequest) } pa, err := serv.mapper.MapFromArtifact(artifact, modelVersionId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{ Artifacts: []*proto.Artifact{pa}, @@ -634,7 +634,7 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod // add explicit Attribution between Artifact and ModelVersion modelVersionId, err := converter.StringToInt64(modelVersionId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } attributions := []*proto.Attribution{} for _, a := range artifactsResp.ArtifactIds { @@ -659,7 +659,7 @@ func (serv *ModelRegistryService) UpsertArtifact(artifact *openapi.Artifact, mod func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, error) { idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ @@ -669,10 +669,10 @@ func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, return nil, err } if len(artifactsResp.Artifacts) > 1 { - return nil, fmt.Errorf("multiple artifacts found for id %s", id) + return nil, fmt.Errorf("multiple artifacts found for id %s: %w", id, api.ErrNotFound) } if len(artifactsResp.Artifacts) == 0 { - return nil, fmt.Errorf("no artifact found for id %s", id) + return nil, fmt.Errorf("no artifact found for id %s: %w", id, api.ErrNotFound) } return serv.mapper.MapToArtifact(artifactsResp.Artifacts[0]) } @@ -680,16 +680,16 @@ func (serv *ModelRegistryService) GetArtifactById(id string) (*openapi.Artifact, func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ArtifactList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } var artifacts []*proto.Artifact var nextPageToken *string if modelVersionId == nil { - return nil, fmt.Errorf("missing model version id, cannot get artifacts without model version") + return nil, fmt.Errorf("missing model version id, cannot get artifacts without model version: %w", api.ErrBadRequest) } ctxId, err := converter.StringToInt64(modelVersionId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ ContextId: ctxId, @@ -705,7 +705,7 @@ func (serv *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, mode for _, a := range artifacts { mapped, err := serv.mapper.MapToArtifact(a) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } @@ -743,7 +743,7 @@ func (serv *ModelRegistryService) GetModelArtifactById(id string) (*openapi.Mode } ma := art.ModelArtifact if ma == nil { - return nil, fmt.Errorf("artifact with id %s is not a model artifact", id) + return nil, fmt.Errorf("artifact with id %s is not a model artifact: %w", id, api.ErrNotFound) } return ma, err } @@ -761,7 +761,7 @@ func (serv *ModelRegistryService) GetModelArtifactByInferenceService(inferenceSe } if artifactList.Size == 0 { - return nil, fmt.Errorf("no artifacts found for model version %s", *mv.Id) + return nil, fmt.Errorf("no artifacts found for model version %s: %w", *mv.Id, api.ErrNotFound) } return &artifactList.Items[0], nil @@ -778,7 +778,7 @@ func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, } else if artifactName != nil && modelVersionId != nil { filterQuery = fmt.Sprintf("name = \"%s\"", converter.PrefixWhenOwned(modelVersionId, *artifactName)) } else { - return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and modelVersionId), or externalId") + return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: %w", api.ErrBadRequest) } glog.Info("filterQuery ", filterQuery) @@ -793,18 +793,18 @@ func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, } if len(artifactsResponse.Artifacts) > 1 { - return nil, fmt.Errorf("multiple model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("multiple model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } if len(artifactsResponse.Artifacts) == 0 { - return nil, fmt.Errorf("no model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("no model artifacts found for artifactName=%v, modelVersionId=%v, externalId=%v: %w", apiutils.ZeroIfNil(artifactName), apiutils.ZeroIfNil(modelVersionId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } artifact0 = artifactsResponse.Artifacts[0] result, err := serv.mapper.MapToModelArtifact(artifact0) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return result, nil @@ -814,7 +814,7 @@ func (serv *ModelRegistryService) GetModelArtifactByParams(artifactName *string, func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ModelArtifactList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } var artifacts []*proto.Artifact @@ -822,7 +822,7 @@ func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, if modelVersionId != nil { ctxId, err := converter.StringToInt64(modelVersionId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ ContextId: ctxId, @@ -849,7 +849,7 @@ func (serv *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, for _, a := range artifacts { mapped, err := serv.mapper.MapToModelArtifact(a) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } @@ -882,14 +882,14 @@ func (serv *ModelRegistryService) UpsertServingEnvironment(servingEnvironment *o withNotEditable, err := serv.openapiConv.OverrideNotEditableForServingEnvironment(converter.NewOpenapiUpdateWrapper(existing, servingEnvironment)) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } servingEnvironment = &withNotEditable } protoCtx, err := serv.mapper.MapFromServingEnvironment(servingEnvironment) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ @@ -904,7 +904,7 @@ func (serv *ModelRegistryService) UpsertServingEnvironment(servingEnvironment *o idAsString := converter.Int64ToString(&protoCtxResp.ContextIds[0]) openapiModel, err := serv.GetServingEnvironmentById(*idAsString) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return openapiModel, nil @@ -916,7 +916,7 @@ func (serv *ModelRegistryService) GetServingEnvironmentById(id string) (*openapi idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ @@ -927,16 +927,16 @@ func (serv *ModelRegistryService) GetServingEnvironmentById(id string) (*openapi } if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple serving environments found for id %s", id) + return nil, fmt.Errorf("multiple serving environments found for id %s: %w", id, api.ErrNotFound) } if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no serving environment found for id %s", id) + return nil, fmt.Errorf("no serving environment found for id %s: %w", id, api.ErrNotFound) } openapiModel, err := serv.mapper.MapToServingEnvironment(getByIdResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return openapiModel, nil @@ -953,7 +953,7 @@ func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, ex } else if externalId != nil { filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) } else { - return nil, fmt.Errorf("invalid parameters call, supply either name or externalId") + return nil, fmt.Errorf("invalid parameters call, supply either name or externalId: %w", api.ErrBadRequest) } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ @@ -967,16 +967,16 @@ func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, ex } if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple serving environments found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("multiple serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no serving environments found for name=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("no serving environments found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } openapiModel, err := serv.mapper.MapToServingEnvironment(getByParamsResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return openapiModel, nil } @@ -985,7 +985,7 @@ func (serv *ModelRegistryService) GetServingEnvironmentByParams(name *string, ex func (serv *ModelRegistryService) GetServingEnvironments(listOptions api.ListOptions) (*openapi.ServingEnvironmentList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } contextsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ TypeName: &serv.nameConfig.ServingEnvironmentTypeName, @@ -999,7 +999,7 @@ func (serv *ModelRegistryService) GetServingEnvironments(listOptions api.ListOpt for _, c := range contextsResp.Contexts { mapped, err := serv.mapper.MapToServingEnvironment(c) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } @@ -1064,7 +1064,7 @@ func (serv *ModelRegistryService) UpsertInferenceService(inferenceService *opena protoCtx, err := serv.mapper.MapFromInferenceService(inferenceService, *servingEnvironment.Id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } protoCtxResp, err := serv.mlmdClient.PutContexts(context.Background(), &proto.PutContextsRequest{ @@ -1080,7 +1080,7 @@ func (serv *ModelRegistryService) UpsertInferenceService(inferenceService *opena if inferenceService.Id == nil { servingEnvironmentId, err := converter.StringToInt64(servingEnvironment.Id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ @@ -1110,7 +1110,7 @@ func (serv *ModelRegistryService) getServingEnvironmentByInferenceServiceId(id s idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getParentResp, err := serv.mlmdClient.GetParentContextsByContext(context.Background(), &proto.GetParentContextsByContextRequest{ @@ -1121,16 +1121,16 @@ func (serv *ModelRegistryService) getServingEnvironmentByInferenceServiceId(id s } if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple ServingEnvironments found for InferenceService %s", id) + return nil, fmt.Errorf("multiple ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) } if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no ServingEnvironments found for InferenceService %s", id) + return nil, fmt.Errorf("no ServingEnvironments found for InferenceService %s: %w", id, api.ErrNotFound) } toReturn, err := serv.mapper.MapToServingEnvironment(getParentResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return toReturn, nil @@ -1142,7 +1142,7 @@ func (serv *ModelRegistryService) GetInferenceServiceById(id string) (*openapi.I idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ @@ -1153,16 +1153,16 @@ func (serv *ModelRegistryService) GetInferenceServiceById(id string) (*openapi.I } if len(getByIdResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple InferenceServices found for id %s", id) + return nil, fmt.Errorf("multiple InferenceServices found for id %s: %w", id, api.ErrNotFound) } if len(getByIdResp.Contexts) == 0 { - return nil, fmt.Errorf("no InferenceService found for id %s", id) + return nil, fmt.Errorf("no InferenceService found for id %s: %w", id, api.ErrNotFound) } toReturn, err := serv.mapper.MapToInferenceService(getByIdResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return toReturn, nil @@ -1177,7 +1177,7 @@ func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, serv } else if externalId != nil { filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) } else { - return nil, fmt.Errorf("invalid parameters call, supply either (name and servingEnvironmentId), or externalId") + return nil, fmt.Errorf("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: %w", api.ErrBadRequest) } getByParamsResp, err := serv.mlmdClient.GetContextsByType(context.Background(), &proto.GetContextsByTypeRequest{ @@ -1191,16 +1191,16 @@ func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, serv } if len(getByParamsResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple inference services found for name=%v, servingEnvironmentId=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("multiple inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } if len(getByParamsResp.Contexts) == 0 { - return nil, fmt.Errorf("no inference services found for name=%v, servingEnvironmentId=%v, externalId=%v", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId)) + return nil, fmt.Errorf("no inference services found for name=%v, servingEnvironmentId=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(servingEnvironmentId), apiutils.ZeroIfNil(externalId), api.ErrNotFound) } toReturn, err := serv.mapper.MapToInferenceService(getByParamsResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return toReturn, nil } @@ -1209,7 +1209,7 @@ func (serv *ModelRegistryService) GetInferenceServiceByParams(name *string, serv func (serv *ModelRegistryService) GetInferenceServices(listOptions api.ListOptions, servingEnvironmentId *string, runtime *string) (*openapi.InferenceServiceList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } queries := []string{} @@ -1238,7 +1238,7 @@ func (serv *ModelRegistryService) GetInferenceServices(listOptions api.ListOptio for _, c := range contextsResp.Contexts { mapped, err := serv.mapper.MapToInferenceService(c) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } @@ -1264,7 +1264,7 @@ func (serv *ModelRegistryService) UpsertServeModel(serveModel *openapi.ServeMode // create glog.Info("Creating new ServeModel") if inferenceServiceId == nil { - return nil, fmt.Errorf("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService") + return nil, fmt.Errorf("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: %w", api.ErrBadRequest) } _, err = serv.GetInferenceServiceById(*inferenceServiceId) if err != nil { @@ -1304,7 +1304,7 @@ func (serv *ModelRegistryService) UpsertServeModel(serveModel *openapi.ServeMode execution, err := serv.mapper.MapFromServeModel(serveModel, *inferenceServiceId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } executionsResp, err := serv.mlmdClient.PutExecutions(context.Background(), &proto.PutExecutionsRequest{ @@ -1318,7 +1318,7 @@ func (serv *ModelRegistryService) UpsertServeModel(serveModel *openapi.ServeMode if inferenceServiceId != nil && serveModel.Id == nil { inferenceServiceId, err := converter.StringToInt64(inferenceServiceId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } associations := []*proto.Association{} for _, a := range executionsResp.ExecutionIds { @@ -1350,7 +1350,7 @@ func (serv *ModelRegistryService) getInferenceServiceByServeModel(id string) (*o idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } getParentResp, err := serv.mlmdClient.GetContextsByExecution(context.Background(), &proto.GetContextsByExecutionRequest{ @@ -1361,16 +1361,16 @@ func (serv *ModelRegistryService) getInferenceServiceByServeModel(id string) (*o } if len(getParentResp.Contexts) > 1 { - return nil, fmt.Errorf("multiple InferenceService found for ServeModel %s", id) + return nil, fmt.Errorf("multiple InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) } if len(getParentResp.Contexts) == 0 { - return nil, fmt.Errorf("no InferenceService found for ServeModel %s", id) + return nil, fmt.Errorf("no InferenceService found for ServeModel %s: %w", id, api.ErrNotFound) } toReturn, err := serv.mapper.MapToInferenceService(getParentResp.Contexts[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return toReturn, nil @@ -1380,7 +1380,7 @@ func (serv *ModelRegistryService) getInferenceServiceByServeModel(id string) (*o func (serv *ModelRegistryService) GetServeModelById(id string) (*openapi.ServeModel, error) { idAsInt, err := converter.StringToInt64(&id) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } executionsResp, err := serv.mlmdClient.GetExecutionsByID(context.Background(), &proto.GetExecutionsByIDRequest{ @@ -1391,16 +1391,16 @@ func (serv *ModelRegistryService) GetServeModelById(id string) (*openapi.ServeMo } if len(executionsResp.Executions) > 1 { - return nil, fmt.Errorf("multiple ServeModels found for id %s", id) + return nil, fmt.Errorf("multiple ServeModels found for id %s: %w", id, api.ErrNotFound) } if len(executionsResp.Executions) == 0 { - return nil, fmt.Errorf("no ServeModel found for id %s", id) + return nil, fmt.Errorf("no ServeModel found for id %s: %w", id, api.ErrNotFound) } result, err := serv.mapper.MapToServeModel(executionsResp.Executions[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } return result, nil @@ -1410,7 +1410,7 @@ func (serv *ModelRegistryService) GetServeModelById(id string) (*openapi.ServeMo func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, inferenceServiceId *string) (*openapi.ServeModelList, error) { listOperationOptions, err := apiutils.BuildListOperationOptions(listOptions) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } var executions []*proto.Execution @@ -1418,7 +1418,7 @@ func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, in if inferenceServiceId != nil { ctxId, err := converter.StringToInt64(inferenceServiceId) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } executionsResp, err := serv.mlmdClient.GetExecutionsByContext(context.Background(), &proto.GetExecutionsByContextRequest{ ContextId: ctxId, @@ -1445,7 +1445,7 @@ func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, in for _, a := range executions { mapped, err := serv.mapper.MapToServeModel(a) if err != nil { - return nil, err + return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest) } results = append(results, *mapped) } diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go index 61a38a70..89567511 100644 --- a/pkg/core/core_test.go +++ b/pkg/core/core_test.go @@ -763,7 +763,7 @@ func (suite *CoreTestSuite) TestGetRegisteredModelByParamsWithNoResults() { _, err := service.GetRegisteredModelByParams(apiutils.Of("not-present"), nil) suite.NotNil(err) - suite.Equal("no registered models found for name=not-present, externalId=", err.Error()) + suite.Equal("no registered models found for name=not-present, externalId=: not found", err.Error()) } func (suite *CoreTestSuite) TestGetRegisteredModelByParamsName() { @@ -819,7 +819,7 @@ func (suite *CoreTestSuite) TestGetRegisteredModelByEmptyParams() { _, err = service.GetRegisteredModelByParams(nil, nil) suite.NotNil(err) - suite.Equal("invalid parameters call, supply either name or externalId", err.Error()) + suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) } func (suite *CoreTestSuite) TestGetRegisteredModelsOrderedById() { @@ -1044,11 +1044,11 @@ func (suite *CoreTestSuite) TestCreateModelVersionFailure() { _, err := service.UpsertModelVersion(modelVersion, nil) suite.NotNil(err) - suite.Equal("missing registered model id, cannot create model version without registered model", err.Error()) + suite.Equal("missing registered model id, cannot create model version without registered model: bad request", err.Error()) _, err = service.UpsertModelVersion(modelVersion, ®isteredModelId) suite.NotNil(err) - suite.Equal("no registered model found for id 9999", err.Error()) + suite.Equal("no registered model found for id 9999: not found", err.Error()) } func (suite *CoreTestSuite) TestUpdateModelVersion() { @@ -1157,7 +1157,7 @@ func (suite *CoreTestSuite) TestUpdateModelVersionFailure() { createdVersion.Id = &wrongId _, err = service.UpsertModelVersion(createdVersion, ®isteredModelId) suite.NotNil(err) - suite.Equal(fmt.Sprintf("no model version found for id %s", wrongId), err.Error()) + suite.Equal(fmt.Sprintf("no model version found for id %s: not found", wrongId), err.Error()) } func (suite *CoreTestSuite) TestGetModelVersionById() { @@ -1206,7 +1206,7 @@ func (suite *CoreTestSuite) TestGetModelVersionByParamsWithNoResults() { _, err := service.GetModelVersionByParams(apiutils.Of("not-present"), ®isteredModelId, nil) suite.NotNil(err) - suite.Equal("no model versions found for versionName=not-present, registeredModelId=1, externalId=", err.Error()) + suite.Equal("no model versions found for versionName=not-present, registeredModelId=1, externalId=: not found", err.Error()) } func (suite *CoreTestSuite) TestGetModelVersionByParamsName() { @@ -1297,7 +1297,7 @@ func (suite *CoreTestSuite) TestGetModelVersionByEmptyParams() { _, err = service.GetModelVersionByParams(nil, nil, nil) suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (versionName and registeredModelId), or externalId", err.Error()) + suite.Equal("invalid parameters call, supply either (versionName and registeredModelId), or externalId: bad request", err.Error()) } func (suite *CoreTestSuite) TestGetModelVersions() { @@ -1451,11 +1451,11 @@ func (suite *CoreTestSuite) TestCreateArtifactFailure() { _, err := service.UpsertArtifact(&artifact, nil) suite.NotNil(err) - suite.Equal("missing model version id, cannot create artifact without model version", err.Error()) + suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) _, err = service.UpsertArtifact(&artifact, &modelVersionId) suite.NotNil(err) - suite.Equal("no model version found for id 9998", err.Error()) + suite.Equal("no model version found for id 9998: not found", err.Error()) } func (suite *CoreTestSuite) TestUpdateArtifact() { @@ -1529,7 +1529,7 @@ func (suite *CoreTestSuite) TestUpdateArtifactFailure() { updatedArtifact.DocArtifact.Id = &wrongId _, err = service.UpsertArtifact(updatedArtifact, &modelVersionId) suite.NotNil(err) - suite.Equal(fmt.Sprintf("no artifact found for id %s", wrongId), err.Error()) + suite.Equal(fmt.Sprintf("no artifact found for id %s: not found", wrongId), err.Error()) } func (suite *CoreTestSuite) TestGetArtifactById() { @@ -1685,11 +1685,11 @@ func (suite *CoreTestSuite) TestCreateModelArtifactFailure() { _, err := service.UpsertModelArtifact(modelArtifact, nil) suite.NotNil(err) - suite.Equal("missing model version id, cannot create artifact without model version", err.Error()) + suite.Equal("missing model version id, cannot create artifact without model version: bad request", err.Error()) _, err = service.UpsertModelArtifact(modelArtifact, &modelVersionId) suite.NotNil(err) - suite.Equal("no model version found for id 9998", err.Error()) + suite.Equal("no model version found for id 9998: not found", err.Error()) } func (suite *CoreTestSuite) TestUpdateModelArtifact() { @@ -1863,7 +1863,7 @@ func (suite *CoreTestSuite) TestGetModelArtifactByEmptyParams() { _, err = service.GetModelArtifactByParams(nil, nil, nil) suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (artifactName and modelVersionId), or externalId", err.Error()) + suite.Equal("invalid parameters call, supply either (artifactName and modelVersionId), or externalId: bad request", err.Error()) } func (suite *CoreTestSuite) TestGetModelArtifactByParamsWithNoResults() { @@ -1874,7 +1874,7 @@ func (suite *CoreTestSuite) TestGetModelArtifactByParamsWithNoResults() { _, err := service.GetModelArtifactByParams(apiutils.Of("not-present"), &modelVersionId, nil) suite.NotNil(err) - suite.Equal("no model artifacts found for artifactName=not-present, modelVersionId=2, externalId=", err.Error()) + suite.Equal("no model artifacts found for artifactName=not-present, modelVersionId=2, externalId=: not found", err.Error()) } func (suite *CoreTestSuite) TestGetModelArtifacts() { @@ -2120,7 +2120,7 @@ func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsWithNoResults() { _, err := service.GetServingEnvironmentByParams(apiutils.Of("not-present"), nil) suite.NotNil(err) - suite.Equal("no serving environments found for name=not-present, externalId=", err.Error()) + suite.Equal("no serving environments found for name=not-present, externalId=: not found", err.Error()) } func (suite *CoreTestSuite) TestGetServingEnvironmentByParamsName() { @@ -2176,7 +2176,7 @@ func (suite *CoreTestSuite) TestGetServingEnvironmentByEmptyParams() { _, err = service.GetServingEnvironmentByParams(nil, nil) suite.NotNil(err) - suite.Equal("invalid parameters call, supply either name or externalId", err.Error()) + suite.Equal("invalid parameters call, supply either name or externalId: bad request", err.Error()) } func (suite *CoreTestSuite) TestGetServingEnvironmentsOrderedById() { @@ -2414,14 +2414,14 @@ func (suite *CoreTestSuite) TestCreateInferenceServiceFailure() { _, err := service.UpsertInferenceService(eut) suite.NotNil(err) - suite.Equal("no serving environment found for id 9999", err.Error()) + suite.Equal("no serving environment found for id 9999: not found", err.Error()) parentResourceId := suite.registerServingEnvironment(service, nil, nil) eut.ServingEnvironmentId = parentResourceId _, err = service.UpsertInferenceService(eut) suite.NotNil(err) - suite.Equal("no registered model found for id 9998", err.Error()) + suite.Equal("no registered model found for id 9998: not found", err.Error()) } func (suite *CoreTestSuite) TestUpdateInferenceService() { @@ -2555,7 +2555,7 @@ func (suite *CoreTestSuite) TestUpdateInferenceServiceFailure() { createdEntity.Id = &wrongId _, err = service.UpsertInferenceService(createdEntity) suite.NotNil(err) - suite.Equal(fmt.Sprintf("no InferenceService found for id %s", wrongId), err.Error()) + suite.Equal(fmt.Sprintf("no InferenceService found for id %s: not found", wrongId), err.Error()) } func (suite *CoreTestSuite) TestGetInferenceServiceById() { @@ -2742,7 +2742,7 @@ func (suite *CoreTestSuite) TestGetInferenceServiceByParamsWithNoResults() { _, err := service.GetInferenceServiceByParams(apiutils.Of("not-present"), &parentResourceId, nil) suite.NotNil(err) - suite.Equal("no inference services found for name=not-present, servingEnvironmentId=1, externalId=", err.Error()) + suite.Equal("no inference services found for name=not-present, servingEnvironmentId=1, externalId=: not found", err.Error()) } func (suite *CoreTestSuite) TestGetInferenceServiceByParamsName() { @@ -2858,7 +2858,7 @@ func (suite *CoreTestSuite) TestGetInferenceServiceByEmptyParams() { _, err = service.GetInferenceServiceByParams(nil, nil, nil) suite.NotNil(err) - suite.Equal("invalid parameters call, supply either (name and servingEnvironmentId), or externalId", err.Error()) + suite.Equal("invalid parameters call, supply either (name and servingEnvironmentId), or externalId: bad request", err.Error()) } func (suite *CoreTestSuite) TestGetInferenceServices() { @@ -3067,11 +3067,11 @@ func (suite *CoreTestSuite) TestCreateServeModelFailure() { _, err := service.UpsertServeModel(eut, nil) suite.NotNil(err) - suite.Equal("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService", err.Error()) + suite.Equal("missing inferenceServiceId, cannot create ServeModel without parent resource InferenceService: bad request", err.Error()) _, err = service.UpsertServeModel(eut, &inferenceServiceId) suite.NotNil(err) - suite.Equal("no model version found for id 9998", err.Error()) + suite.Equal("no model version found for id 9998: not found", err.Error()) } func (suite *CoreTestSuite) TestUpdateServeModel() { @@ -3180,7 +3180,7 @@ func (suite *CoreTestSuite) TestUpdateServeModelFailure() { updatedEntity.Id = &wrongId _, err = service.UpsertServeModel(updatedEntity, &inferenceServiceId) suite.NotNil(err) - suite.Equal(fmt.Sprintf("no ServeModel found for id %s", wrongId), err.Error()) + suite.Equal(fmt.Sprintf("no ServeModel found for id %s: not found", wrongId), err.Error()) } func (suite *CoreTestSuite) TestGetServeModelById() {