From f33b037cbd5b4ec9e1eff62141386d57c68773d0 Mon Sep 17 00:00:00 2001 From: lucferbux Date: Mon, 9 Sep 2024 10:28:15 +0200 Subject: [PATCH] Set up model registry context and apiHooks Signed-off-by: lucferbux --- .../src/__mocks__/mockModelArtifact.ts | 34 +++ .../src/app/api/__tests__/errorUtils.spec.ts | 2 +- .../src/app/api/__tests__/service.spec.ts | 204 ++++++++++-------- clients/ui/frontend/src/app/api/apiUtils.ts | 2 +- clients/ui/frontend/src/app/api/errorUtils.ts | 2 +- clients/ui/frontend/src/app/api/k8s.ts | 8 +- clients/ui/frontend/src/app/api/service.ts | 179 +++------------ clients/ui/frontend/src/app/api/types.ts | 17 ++ .../ui/frontend/src/app/api/useAPIState.ts | 32 +++ .../src/app/context/ModelRegistryContext.tsx | 44 ++++ .../context/ModelRegistrySelectorContext.tsx | 58 +++++ .../app/context/useModelRegistryAPIState.tsx | 54 +++++ .../useModelArtifactsByVersionId.spec.ts | 88 ++++++++ .../app/hooks/useModelArtifactsByVersionId.ts | 31 +++ .../src/app/hooks/useModelRegistries.ts | 16 ++ .../src/app/hooks/useModelRegistryAPI.ts | 16 ++ .../src/app/hooks/useModelVersionById.ts | 30 +++ .../useModelVersionsByRegisteredModel.ts | 36 ++++ .../src/app/hooks/useRegisteredModelById.ts | 30 +++ .../src/app/hooks/useRegisteredModels.ts | 28 +++ clients/ui/frontend/src/app/types.ts | 4 +- clients/ui/frontend/src/types.ts | 11 - .../frontend/src/utilities/useFetchState.ts | 2 +- 23 files changed, 678 insertions(+), 250 deletions(-) create mode 100644 clients/ui/frontend/src/__mocks__/mockModelArtifact.ts create mode 100644 clients/ui/frontend/src/app/api/types.ts create mode 100644 clients/ui/frontend/src/app/api/useAPIState.ts create mode 100644 clients/ui/frontend/src/app/context/ModelRegistryContext.tsx create mode 100644 clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx create mode 100644 clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx create mode 100644 clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts create mode 100644 clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts create mode 100644 clients/ui/frontend/src/app/hooks/useModelRegistries.ts create mode 100644 clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts create mode 100644 clients/ui/frontend/src/app/hooks/useModelVersionById.ts create mode 100644 clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts create mode 100644 clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts create mode 100644 clients/ui/frontend/src/app/hooks/useRegisteredModels.ts diff --git a/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts new file mode 100644 index 000000000..8f2bb628b --- /dev/null +++ b/clients/ui/frontend/src/__mocks__/mockModelArtifact.ts @@ -0,0 +1,34 @@ +import { ModelArtifact, ModelArtifactState } from '~/app/types'; + +type MockModelArtifact = { + id?: string; + name?: string; + uri?: string; + state?: ModelArtifactState; + author?: string; +}; + +export const mockModelArtifact = ({ + id = '1', + name = 'test', + uri = 'test', + state = ModelArtifactState.LIVE, + author = 'Author 1', +}: MockModelArtifact): ModelArtifact => ({ + id, + name, + externalID: '1234132asdfasdf', + description: '', + createTimeSinceEpoch: '1710404288975', + lastUpdateTimeSinceEpoch: '1710404288975', + customProperties: {}, + uri, + state, + author, + modelFormatName: 'test', + storageKey: 'test', + storagePath: 'test', + modelFormatVersion: 'test', + serviceAccountName: 'test', + artifactType: 'test', +}); diff --git a/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts b/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts index 3c225152a..57244560a 100644 --- a/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts +++ b/clients/ui/frontend/src/app/api/__tests__/errorUtils.spec.ts @@ -1,5 +1,5 @@ import { NotReadyError } from '~/utilities/useFetchState'; -import { APIError } from '~/types'; +import { APIError } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; diff --git a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts index 1e2a36e23..6bfe6e6a7 100644 --- a/clients/ui/frontend/src/app/api/__tests__/service.spec.ts +++ b/clients/ui/frontend/src/app/api/__tests__/service.spec.ts @@ -45,18 +45,21 @@ const K8sAPIOptionsMock = {}; describe('createRegisteredModel', () => { it('should call restCREATE and handleRestFailures to create registered model', () => { expect( - createRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: '1', - name: 'test new registered model', - state: ModelState.LIVE, - customProperties: {}, - }), + createRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: '1', + name: 'test new registered model', + state: ModelState.LIVE, + customProperties: {}, + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models`, { description: 'test', externalID: '1', @@ -75,20 +78,23 @@ describe('createRegisteredModel', () => { describe('createModelVersion', () => { it('should call restCREATE and handleRestFailures to create model version', () => { expect( - createModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: '1', - author: 'test author', - registeredModelId: '1', - name: 'test new model version', - state: ModelState.LIVE, - customProperties: {}, - }), + createModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: '1', + author: 'test author', + registeredModelId: '1', + name: 'test new model version', + state: ModelState.LIVE, + customProperties: {}, + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions`, { description: 'test', externalID: '1', @@ -109,7 +115,9 @@ describe('createModelVersion', () => { describe('createModelVersionForRegisteredModel', () => { it('should call restCREATE and handleRestFailures to create model version for a model', () => { expect( - createModelVersionForRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1', { + createModelVersionForRegisteredModel( + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + )(K8sAPIOptionsMock, '1', { description: 'test', externalID: '1', author: 'test author', @@ -121,8 +129,8 @@ describe('createModelVersionForRegisteredModel', () => { ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1/versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1/versions`, { description: 'test', externalID: '1', @@ -143,25 +151,28 @@ describe('createModelVersionForRegisteredModel', () => { describe('createModelArtifact', () => { it('should call restCREATE and handleRestFailures to create model artifact', () => { expect( - createModelArtifact('hostPath', 'model-registry-1')(K8sAPIOptionsMock, { - description: 'test', - externalID: 'test', - uri: 'test-uri', - state: ModelArtifactState.LIVE, - name: 'test-name', - modelFormatName: 'test-modelformatname', - storageKey: 'teststoragekey', - storagePath: 'teststoragePath', - modelFormatVersion: 'testmodelFormatVersion', - serviceAccountName: 'testserviceAccountname', - customProperties: {}, - artifactType: 'model-artifact', - }), + createModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + { + description: 'test', + externalID: 'test', + uri: 'test-uri', + state: ModelArtifactState.LIVE, + name: 'test-name', + modelFormatName: 'test-modelformatname', + storageKey: 'teststoragekey', + storagePath: 'teststoragePath', + modelFormatVersion: 'testmodelFormatVersion', + serviceAccountName: 'testserviceAccountname', + customProperties: {}, + artifactType: 'model-artifact', + }, + ), ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts`, { description: 'test', externalID: 'test', @@ -187,7 +198,9 @@ describe('createModelArtifact', () => { describe('createModelArtifactForModelVersion', () => { it('should call restCREATE and handleRestFailures to create model artifact for version', () => { expect( - createModelArtifactForModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '2', { + createModelArtifactForModelVersion( + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + )(K8sAPIOptionsMock, '2', { description: 'test', externalID: 'test', uri: 'test-uri', @@ -204,8 +217,8 @@ describe('createModelArtifactForModelVersion', () => { ).toBe(mockResultPromise); expect(restCREATEMock).toHaveBeenCalledTimes(1); expect(restCREATEMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/2/artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/2/artifacts`, { description: 'test', externalID: 'test', @@ -230,13 +243,16 @@ describe('createModelArtifactForModelVersion', () => { describe('getRegisteredModel', () => { it('should call restGET and handleRestFailures to fetch registered model', () => { - expect(getRegisteredModel('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1`, {}, K8sAPIOptionsMock, ); @@ -247,13 +263,16 @@ describe('getRegisteredModel', () => { describe('getModelVersion', () => { it('should call restGET and handleRestFailures to fetch model version', () => { - expect(getModelVersion('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1`, {}, K8sAPIOptionsMock, ); @@ -264,13 +283,16 @@ describe('getModelVersion', () => { describe('getModelArtifact', () => { it('should call restGET and handleRestFailures to fetch model version', () => { - expect(getModelArtifact('hostPath', 'model-registry-1')(K8sAPIOptionsMock, '1')).toBe( - mockResultPromise, - ); + expect( + getModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + K8sAPIOptionsMock, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts/1`, {}, K8sAPIOptionsMock, ); @@ -281,11 +303,13 @@ describe('getModelArtifact', () => { describe('getListRegisteredModels', () => { it('should call restGET and handleRestFailures to list registered models', () => { - expect(getListRegisteredModels('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListRegisteredModels(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models`, {}, K8sAPIOptionsMock, ); @@ -296,11 +320,13 @@ describe('getListRegisteredModels', () => { describe('getListModelArtifacts', () => { it('should call restGET and handleRestFailures to list models artifacts', () => { - expect(getListModelArtifacts('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListModelArtifacts(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts`, {}, K8sAPIOptionsMock, ); @@ -311,11 +337,13 @@ describe('getListModelArtifacts', () => { describe('getListModelVersions', () => { it('should call restGET and handleRestFailures to list models versions', () => { - expect(getListModelVersions('hostPath', 'model-registry-1')({})).toBe(mockResultPromise); + expect( + getListModelVersions(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)({}), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions`, {}, K8sAPIOptionsMock, ); @@ -326,13 +354,16 @@ describe('getListModelVersions', () => { describe('getModelVersionsByRegisteredModel', () => { it('should call restGET and handleRestFailures to list models versions by registered model', () => { - expect(getModelVersionsByRegisteredModel('hostPath', 'model-registry-1')({}, '1')).toBe( - mockResultPromise, - ); + expect( + getModelVersionsByRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + {}, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1/versions`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1/versions`, {}, K8sAPIOptionsMock, ); @@ -343,13 +374,16 @@ describe('getModelVersionsByRegisteredModel', () => { describe('getModelArtifactsByModelVersion', () => { it('should call restGET and handleRestFailures to list models artifacts by model version', () => { - expect(getModelArtifactsByModelVersion('hostPath', 'model-registry-1')({}, '1')).toBe( - mockResultPromise, - ); + expect( + getModelArtifactsByModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( + {}, + '1', + ), + ).toBe(mockResultPromise); expect(restGETMock).toHaveBeenCalledTimes(1); expect(restGETMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1/artifacts`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1/artifacts`, {}, K8sAPIOptionsMock, ); @@ -361,7 +395,7 @@ describe('getModelArtifactsByModelVersion', () => { describe('patchRegisteredModel', () => { it('should call restPATCH and handleRestFailures to update registered model', () => { expect( - patchRegisteredModel('hostPath', 'model-registry-1')( + patchRegisteredModel(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -369,8 +403,8 @@ describe('patchRegisteredModel', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/registered_models/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/registered_models/1`, { description: 'new test' }, K8sAPIOptionsMock, ); @@ -382,7 +416,7 @@ describe('patchRegisteredModel', () => { describe('patchModelVersion', () => { it('should call restPATCH and handleRestFailures to update model version', () => { expect( - patchModelVersion('hostPath', 'model-registry-1')( + patchModelVersion(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -390,8 +424,8 @@ describe('patchModelVersion', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_versions/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_versions/1`, { description: 'new test' }, K8sAPIOptionsMock, ); @@ -403,7 +437,7 @@ describe('patchModelVersion', () => { describe('patchModelArtifact', () => { it('should call restPATCH and handleRestFailures to update model artifact', () => { expect( - patchModelArtifact('hostPath', 'model-registry-1')( + patchModelArtifact(`/api/${BFF_API_VERSION}/model_registry/model-registry-1/`)( K8sAPIOptionsMock, { description: 'new test' }, '1', @@ -411,8 +445,8 @@ describe('patchModelArtifact', () => { ).toBe(mockResultPromise); expect(restPATCHMock).toHaveBeenCalledTimes(1); expect(restPATCHMock).toHaveBeenCalledWith( - 'hostPath', - `/api/${BFF_API_VERSION}/model_registry/model-registry-1/model_artifacts/1`, + `/api/${BFF_API_VERSION}/model_registry/model-registry-1/`, + `/model_artifacts/1`, { description: 'new test' }, K8sAPIOptionsMock, ); diff --git a/clients/ui/frontend/src/app/api/apiUtils.ts b/clients/ui/frontend/src/app/api/apiUtils.ts index d4adff6c1..69015e5e5 100644 --- a/clients/ui/frontend/src/app/api/apiUtils.ts +++ b/clients/ui/frontend/src/app/api/apiUtils.ts @@ -1,4 +1,4 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { EitherOrNone } from '~/typeHelpers'; export const mergeRequestInit = ( diff --git a/clients/ui/frontend/src/app/api/errorUtils.ts b/clients/ui/frontend/src/app/api/errorUtils.ts index 4cb92823b..59975c726 100644 --- a/clients/ui/frontend/src/app/api/errorUtils.ts +++ b/clients/ui/frontend/src/app/api/errorUtils.ts @@ -1,4 +1,4 @@ -import { APIError } from '~/types'; +import { APIError } from '~/app/api/types'; import { isCommonStateError } from '~/utilities/useFetchState'; const isError = (e: unknown): e is APIError => diff --git a/clients/ui/frontend/src/app/api/k8s.ts b/clients/ui/frontend/src/app/api/k8s.ts index e17e55dbe..5138090da 100644 --- a/clients/ui/frontend/src/app/api/k8s.ts +++ b/clients/ui/frontend/src/app/api/k8s.ts @@ -1,10 +1,10 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; import { restGET } from '~/app/api/apiUtils'; -import { ModelRegistry } from '~/app/types'; +import { ModelRegistryList } from '~/app/types'; import { BFF_API_VERSION } from '~/app/const'; -export const getModelRegistries = +export const getListModelRegistries = (hostPath: string) => - (opts: APIOptions): Promise => + (opts: APIOptions): Promise => handleRestFailures(restGET(hostPath, `/api/${BFF_API_VERSION}/model_registry`, {}, opts)); diff --git a/clients/ui/frontend/src/app/api/service.ts b/clients/ui/frontend/src/app/api/service.ts index 42f8dbb56..696c46bad 100644 --- a/clients/ui/frontend/src/app/api/service.ts +++ b/clients/ui/frontend/src/app/api/service.ts @@ -10,218 +10,107 @@ import { RegisteredModel, } from '~/app/types'; import { restCREATE, restGET, restPATCH } from '~/app/api/apiUtils'; -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; import { handleRestFailures } from '~/app/api/errorUtils'; -import { BFF_API_VERSION } from '~/app/const'; export const createRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateRegisteredModelData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/registered_models`, data, {}, opts)); export const createModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateModelVersionData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/model_versions`, data, {}, opts)); + export const createModelVersionForRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, registeredModelId: string, data: CreateModelVersionData, ): Promise => handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}/versions`, - data, - {}, - opts, - ), + restCREATE(hostPath, `/registered_models/${registeredModelId}/versions`, data, {}, opts), ); export const createModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: CreateModelArtifactData): Promise => - handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts`, - data, - {}, - opts, - ), - ); + handleRestFailures(restCREATE(hostPath, `/model_artifacts`, data, {}, opts)); export const createModelArtifactForModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, modelVersionId: string, data: CreateModelArtifactData, ): Promise => handleRestFailures( - restCREATE( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelVersionId}/artifacts`, - data, - {}, - opts, - ), + restCREATE(hostPath, `/model_versions/${modelVersionId}/artifacts`, data, {}, opts), ); export const getRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, registeredModelId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/registered_models/${registeredModelId}`, {}, opts)); export const getModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelversionId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelversionId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions/${modelversionId}`, {}, opts)); export const getModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelArtifactId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts/${modelArtifactId}`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_artifacts/${modelArtifactId}`, {}, opts)); export const getListModelArtifacts = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_artifacts`, {}, opts)); export const getListModelVersions = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions`, {}, opts)); export const getListRegisteredModels = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/registered_models`, {}, opts)); export const getModelVersionsByRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, registeredmodelId: string): Promise => handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredmodelId}/versions`, - {}, - opts, - ), + restGET(hostPath, `/registered_models/${registeredmodelId}/versions`, {}, opts), ); export const getModelArtifactsByModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, modelVersionId: string): Promise => - handleRestFailures( - restGET( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelVersionId}/artifacts`, - {}, - opts, - ), - ); + handleRestFailures(restGET(hostPath, `/model_versions/${modelVersionId}/artifacts`, {}, opts)); export const patchRegisteredModel = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, data: Partial, registeredModelId: string, ): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/registered_models/${registeredModelId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/registered_models/${registeredModelId}`, data, opts)); export const patchModelVersion = - (hostPath: string, mrName: string) => + (hostPath: string) => (opts: APIOptions, data: Partial, modelversionId: string): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_versions/${modelversionId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/model_versions/${modelversionId}`, data, opts)); export const patchModelArtifact = - (hostPath: string, mrName: string) => + (hostPath: string) => ( opts: APIOptions, data: Partial, modelartifactId: string, ): Promise => - handleRestFailures( - restPATCH( - hostPath, - `/api/${BFF_API_VERSION}/model_registry/${mrName}/model_artifacts/${modelartifactId}`, - data, - opts, - ), - ); + handleRestFailures(restPATCH(hostPath, `/model_artifacts/${modelartifactId}`, data, opts)); diff --git a/clients/ui/frontend/src/app/api/types.ts b/clients/ui/frontend/src/app/api/types.ts new file mode 100644 index 000000000..e7994e60f --- /dev/null +++ b/clients/ui/frontend/src/app/api/types.ts @@ -0,0 +1,17 @@ +export type APIOptions = { + dryRun?: boolean; + signal?: AbortSignal; + parseJSON?: boolean; +}; + +export type APIError = { + code: string; + message: string; +}; + +export type APIState = { + /** If API will successfully call */ + apiAvailable: boolean; + /** The available API functions */ + api: T; +}; diff --git a/clients/ui/frontend/src/app/api/useAPIState.ts b/clients/ui/frontend/src/app/api/useAPIState.ts new file mode 100644 index 000000000..4783e8cbe --- /dev/null +++ b/clients/ui/frontend/src/app/api/useAPIState.ts @@ -0,0 +1,32 @@ +import * as React from 'react'; +import { APIState } from '~/app/api/types'; + +const useAPIState = ( + hostPath: string | null, + createAPI: (path: string) => T, +): [apiState: APIState, refreshAPIState: () => void] => { + const [internalAPIToggleState, setInternalAPIToggleState] = React.useState(false); + + const refreshAPIState = React.useCallback(() => { + setInternalAPIToggleState((v) => !v); + }, []); + + const apiState = React.useMemo>(() => { + let path = hostPath; + if (!path) { + // TODO: we need to figure out maybe a stopgap or something + path = ''; + } + const api = createAPI(path); + + return { + apiAvailable: !!path, + api, + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [createAPI, hostPath, internalAPIToggleState]); + + return [apiState, refreshAPIState]; +}; + +export default useAPIState; diff --git a/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx b/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx new file mode 100644 index 000000000..6c107e277 --- /dev/null +++ b/clients/ui/frontend/src/app/context/ModelRegistryContext.tsx @@ -0,0 +1,44 @@ +import * as React from 'react'; +import { BFF_API_VERSION } from '~/app/const'; +import useModelRegistryAPIState, { ModelRegistryAPIState } from './useModelRegistryAPIState'; + +export type ModelRegistryContextType = { + apiState: ModelRegistryAPIState; + refreshAPIState: () => void; +}; + +type ModelRegistryContextProviderProps = { + children: React.ReactNode; + modelRegistryName: string; +}; + +export const ModelRegistryContext = React.createContext({ + // eslint-disable-next-line @typescript-eslint/consistent-type-assertions + apiState: { apiAvailable: false, api: null as unknown as ModelRegistryAPIState['api'] }, + refreshAPIState: () => undefined, +}); + +export const ModelRegistryContextProvider: React.FC = ({ + children, + modelRegistryName, +}) => { + const hostPath = modelRegistryName + ? `/api/${BFF_API_VERSION}/model_registry/${modelRegistryName}` + : null; + + const [apiState, refreshAPIState] = useModelRegistryAPIState(hostPath); + + return ( + ({ + apiState, + refreshAPIState, + }), + [apiState, refreshAPIState], + )} + > + {children} + + ); +}; diff --git a/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx b/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx new file mode 100644 index 000000000..5ddbfd990 --- /dev/null +++ b/clients/ui/frontend/src/app/context/ModelRegistrySelectorContext.tsx @@ -0,0 +1,58 @@ +import * as React from 'react'; +import { ModelRegistry } from '~/app/types'; +import useModelRegistries from '~/app/hooks/useModelRegistries'; + +export type ModelRegistrySelectorContextType = { + modelRegistriesLoaded: boolean; + modelRegistriesLoadError?: Error; + modelRegistries: ModelRegistry[]; + preferredModelRegistry: ModelRegistry | undefined; + updatePreferredModelRegistry: (modelRegistry: ModelRegistry | undefined) => void; +}; + +type ModelRegistrySelectorContextProviderProps = { + children: React.ReactNode; +}; + +export const ModelRegistrySelectorContext = React.createContext({ + modelRegistriesLoaded: false, + modelRegistriesLoadError: undefined, + modelRegistries: [], + preferredModelRegistry: undefined, + updatePreferredModelRegistry: () => undefined, +}); + +export const ModelRegistrySelectorContextProvider: React.FC< + ModelRegistrySelectorContextProviderProps +> = ({ children, ...props }) => ( + + {children} + +); + +const EnabledModelRegistrySelectorContextProvider: React.FC< + ModelRegistrySelectorContextProviderProps +> = ({ children }) => { + const [modelRegistries, isLoaded, error] = useModelRegistries(); + const [preferredModelRegistry, setPreferredModelRegistry] = + React.useState(undefined); + + const firstModelRegistry = modelRegistries.length > 0 ? modelRegistries[0] : null; + + return ( + ({ + modelRegistriesLoaded: isLoaded, + modelRegistriesLoadError: error, + modelRegistries, + preferredModelRegistry: preferredModelRegistry ?? firstModelRegistry ?? undefined, + updatePreferredModelRegistry: setPreferredModelRegistry, + }), + [isLoaded, error, modelRegistries, preferredModelRegistry, firstModelRegistry], + )} + > + {children} + + ); +}; diff --git a/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx b/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx new file mode 100644 index 000000000..9b1465ba0 --- /dev/null +++ b/clients/ui/frontend/src/app/context/useModelRegistryAPIState.tsx @@ -0,0 +1,54 @@ +import React from 'react'; +import { APIState } from '~/app/api/types'; +import { ModelRegistryAPIs } from '~/app/types'; +import { + createModelArtifact, + createModelArtifactForModelVersion, + createModelVersion, + createModelVersionForRegisteredModel, + createRegisteredModel, + getListModelArtifacts, + getListModelVersions, + getListRegisteredModels, + getModelArtifact, + getModelArtifactsByModelVersion, + getModelVersion, + getModelVersionsByRegisteredModel, + getRegisteredModel, + patchModelArtifact, + patchModelVersion, + patchRegisteredModel, +} from '~/app/api/service'; +import useAPIState from '~/app/api/useAPIState'; + +export type ModelRegistryAPIState = APIState; + +const useModelRegistryAPIState = ( + hostPath: string | null, +): [apiState: ModelRegistryAPIState, refreshAPIState: () => void] => { + const createAPI = React.useCallback( + (path: string) => ({ + createRegisteredModel: createRegisteredModel(path), + createModelVersion: createModelVersion(path), + createModelVersionForRegisteredModel: createModelVersionForRegisteredModel(path), + createModelArtifact: createModelArtifact(path), + createModelArtifactForModelVersion: createModelArtifactForModelVersion(path), + getRegisteredModel: getRegisteredModel(path), + getModelVersion: getModelVersion(path), + getModelArtifact: getModelArtifact(path), + listModelArtifacts: getListModelArtifacts(path), + listModelVersions: getListModelVersions(path), + listRegisteredModels: getListRegisteredModels(path), + getModelVersionsByRegisteredModel: getModelVersionsByRegisteredModel(path), + getModelArtifactsByModelVersion: getModelArtifactsByModelVersion(path), + patchRegisteredModel: patchRegisteredModel(path), + patchModelVersion: patchModelVersion(path), + patchModelArtifact: patchModelArtifact(path), + }), + [], + ); + + return useAPIState(hostPath, createAPI); +}; + +export default useModelRegistryAPIState; diff --git a/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts b/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts new file mode 100644 index 000000000..3656cfafe --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/__tests__/useModelArtifactsByVersionId.spec.ts @@ -0,0 +1,88 @@ +import useModelArtifactsByVersionId from '~/app/hooks/useModelArtifactsByVersionId'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; +import { NotReadyError } from '~/utilities/useFetchState'; +import { ModelRegistryAPIs } from '~/app/types'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; +import { testHook } from '~/__tests__/unit/testUtils/hooks'; + +global.fetch = jest.fn(); +// Mock the useModelRegistryAPI hook +jest.mock('~/app/hooks/useModelRegistryAPI', () => ({ + useModelRegistryAPI: jest.fn(), +})); + +const mockUseModelRegistryAPI = jest.mocked(useModelRegistryAPI); + +const mockModelRegistryAPIs: ModelRegistryAPIs = { + createRegisteredModel: jest.fn(), + createModelVersionForRegisteredModel: jest.fn(), + createModelArtifactForModelVersion: jest.fn(), + getRegisteredModel: jest.fn(), + getModelVersion: jest.fn(), + listRegisteredModels: jest.fn(), + getModelVersionsByRegisteredModel: jest.fn(), + getModelArtifactsByModelVersion: jest.fn(), + patchRegisteredModel: jest.fn(), + patchModelVersion: jest.fn(), +}; + +describe('useModelArtifactsByVersionId', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should return NotReadyError if API is not available', async () => { + mockUseModelRegistryAPI.mockReturnValue({ + api: mockModelRegistryAPIs, + apiAvailable: false, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)(); + const [, , error] = result.current; + + expect(error).toBe('API not yet available'); + expect(error?.message).toBeInstanceOf(NotReadyError); + }); + + it('should return NotReadyError if modelVersionId is not provided', async () => { + mockUseModelRegistryAPI.mockReturnValue({ + api: mockModelRegistryAPIs, + apiAvailable: true, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)(); + const [, , error] = result.current; + + expect(error).toBeInstanceOf(NotReadyError); + expect(error?.message).toBe('No model registeredModel id'); + }); + + it('should fetch model artifacts if API is available and modelVersionId is provided', async () => { + const mockedResponse = { + items: [mockModelArtifact({ id: 'artifact-1' })], + size: 1, + pageSize: 1, + }; + const mockGetModelArtifactsByModelVersion = jest.fn().mockResolvedValue(mockedResponse); + + mockUseModelRegistryAPI.mockReturnValue({ + api: { + ...mockModelRegistryAPIs, + getModelArtifactsByModelVersion: mockGetModelArtifactsByModelVersion, + }, + apiAvailable: false, + refreshAllAPI: jest.fn(), + }); + + const { result } = testHook(useModelArtifactsByVersionId)('version-id'); + const [data] = result.current; + + expect(data).toEqual(mockedResponse); + expect(mockGetModelArtifactsByModelVersion).toHaveBeenCalledWith( + expect.any(Object), + 'version-id', + ); + }); +}); diff --git a/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts b/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts new file mode 100644 index 000000000..9bd973c2e --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelArtifactsByVersionId.ts @@ -0,0 +1,31 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { ModelArtifactList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelArtifactsByVersionId = (modelVersionId?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + const callback = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + if (!modelVersionId) { + return Promise.reject(new NotReadyError('No model registeredModel id')); + } + return api.getModelArtifactsByModelVersion(opts, modelVersionId); + }, + [api, apiAvailable, modelVersionId], + ); + return useFetchState( + callback, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useModelArtifactsByVersionId; diff --git a/clients/ui/frontend/src/app/hooks/useModelRegistries.ts b/clients/ui/frontend/src/app/hooks/useModelRegistries.ts new file mode 100644 index 000000000..aa482b099 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelRegistries.ts @@ -0,0 +1,16 @@ +import * as React from 'react'; +import { BFF_API_VERSION } from '~/app/const'; +import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; +import { ModelRegistryList } from '~/app/types'; +import { getListModelRegistries } from '~/app/api/k8s'; + +const useModelRegistries = (): FetchState => { + const listModelRegistries = getListModelRegistries(`/api/${BFF_API_VERSION}/model_registry`); + const callback = React.useCallback>( + (opts) => listModelRegistries(opts), + [listModelRegistries], + ); + return useFetchState(callback, [], { initialPromisePurity: true }); +}; + +export default useModelRegistries; diff --git a/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts b/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts new file mode 100644 index 000000000..5a211568d --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelRegistryAPI.ts @@ -0,0 +1,16 @@ +import * as React from 'react'; +import { ModelRegistryAPIState } from '~/app/context/useModelRegistryAPIState'; +import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; + +type UseModelRegistryAPI = ModelRegistryAPIState & { + refreshAllAPI: () => void; +}; + +export const useModelRegistryAPI = (): UseModelRegistryAPI => { + const { apiState, refreshAPIState: refreshAllAPI } = React.useContext(ModelRegistryContext); + + return { + refreshAllAPI, + ...apiState, + }; +}; diff --git a/clients/ui/frontend/src/app/hooks/useModelVersionById.ts b/clients/ui/frontend/src/app/hooks/useModelVersionById.ts new file mode 100644 index 000000000..4a82a1721 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelVersionById.ts @@ -0,0 +1,30 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { ModelVersion } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelVersionById = (modelVersionId?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + if (!modelVersionId) { + return Promise.reject(new NotReadyError('No model version id')); + } + + return api.getModelVersion(opts, modelVersionId); + }, + [api, apiAvailable, modelVersionId], + ); + + return useFetchState(call, null); +}; + +export default useModelVersionById; diff --git a/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts b/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts new file mode 100644 index 000000000..8e24bbdea --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useModelVersionsByRegisteredModel.ts @@ -0,0 +1,36 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { ModelVersionList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useModelVersionsByRegisteredModel = ( + registeredModelId?: string, +): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + if (!registeredModelId) { + return Promise.reject(new NotReadyError('No model registeredModel id')); + } + + return api.getModelVersionsByRegisteredModel(opts, registeredModelId); + }, + [api, apiAvailable, registeredModelId], + ); + + return useFetchState( + call, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useModelVersionsByRegisteredModel; diff --git a/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts b/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts new file mode 100644 index 000000000..2c8ea9d57 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useRegisteredModelById.ts @@ -0,0 +1,30 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { RegisteredModel } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useRegisteredModelById = (registeredModel?: string): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + + const call = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + if (!registeredModel) { + return Promise.reject(new NotReadyError('No registered model id')); + } + + return api.getRegisteredModel(opts, registeredModel); + }, + [api, apiAvailable, registeredModel], + ); + + return useFetchState(call, null); +}; + +export default useRegisteredModelById; diff --git a/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts b/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts new file mode 100644 index 000000000..36766cf53 --- /dev/null +++ b/clients/ui/frontend/src/app/hooks/useRegisteredModels.ts @@ -0,0 +1,28 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { RegisteredModelList } from '~/app/types'; +import { useModelRegistryAPI } from '~/app/hooks/useModelRegistryAPI'; + +const useRegisteredModels = (): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + const callback = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + return api.listRegisteredModels(opts).then((r) => r); + }, + [api, apiAvailable], + ); + return useFetchState( + callback, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useRegisteredModels; diff --git a/clients/ui/frontend/src/app/types.ts b/clients/ui/frontend/src/app/types.ts index 17fcc5890..912c845e1 100644 --- a/clients/ui/frontend/src/app/types.ts +++ b/clients/ui/frontend/src/app/types.ts @@ -1,4 +1,4 @@ -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; export enum ModelState { LIVE = 'LIVE', @@ -21,6 +21,8 @@ export type ModelRegistry = { description: string; }; +export type ModelRegistryList = ModelRegistry[]; + export enum ModelRegistryMetadataType { INT = 'MetadataIntValue', DOUBLE = 'MetadataDoubleValue', diff --git a/clients/ui/frontend/src/types.ts b/clients/ui/frontend/src/types.ts index 0be5cb1a9..34f4c36fc 100644 --- a/clients/ui/frontend/src/types.ts +++ b/clients/ui/frontend/src/types.ts @@ -19,14 +19,3 @@ export type CommonConfig = { export type FeatureFlag = { modelRegistry: boolean; }; - -export type APIOptions = { - dryRun?: boolean; - signal?: AbortSignal; - parseJSON?: boolean; -}; - -export type APIError = { - code: string; - message: string; -}; diff --git a/clients/ui/frontend/src/utilities/useFetchState.ts b/clients/ui/frontend/src/utilities/useFetchState.ts index 64b2e3eb3..aa688d349 100644 --- a/clients/ui/frontend/src/utilities/useFetchState.ts +++ b/clients/ui/frontend/src/utilities/useFetchState.ts @@ -1,5 +1,5 @@ import * as React from 'react'; -import { APIOptions } from '~/types'; +import { APIOptions } from '~/app/api/types'; /** * Allows "I'm not ready" rejections if you lack a lazy provided prop