Skip to content

Commit

Permalink
feat: test llm
Browse files Browse the repository at this point in the history
  • Loading branch information
paulclindo committed Dec 20, 2024
1 parent f28b2b8 commit 3a7fcd3
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 21 deletions.
77 changes: 59 additions & 18 deletions apps/shinkai-desktop/src/pages/add-ai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import {
import { useAddLLMProvider } from '@shinkai_network/shinkai-node-state/v2/mutations/addLLMProvider/useAddLLMProvider';
import {
Button,
ErrorMessage,
Form,
FormControl,
FormField,
Expand All @@ -30,6 +29,7 @@ import {
TextField,
} from '@shinkai_network/shinkai-ui';
import { cn } from '@shinkai_network/shinkai-ui/utils';
import { Loader2 } from 'lucide-react';
import { useEffect, useState } from 'react';
import { useForm } from 'react-hook-form';
import { useNavigate } from 'react-router-dom';
Expand Down Expand Up @@ -114,12 +114,7 @@ const AddAIPage = () => {
}
}, [addAgentForm, preSelectedAiProvider]);

const {
mutateAsync: addLLMProvider,
isPending,
isError,
error,
} = useAddLLMProvider({
const { mutateAsync: addLLMProvider, isPending } = useAddLLMProvider({
onSuccess: (_, variables) => {
navigate('/inboxes', {
state: {
Expand Down Expand Up @@ -247,11 +242,37 @@ const AddAIPage = () => {
toolkit_permissions: [],
model,
},
enableTest: false,
});
};
const handleTestAndSave = async (data: AddAgentFormSchema) => {
if (!auth) return;
let model = getModelObject(data.model, data.modelType);
if (isCustomModelMode && data.modelCustom && data.modelTypeCustom) {
model = getModelObject(data.modelCustom, data.modelTypeCustom);
} else if (isCustomModelType && data.modelTypeCustom) {
model = getModelObject(data.model, data.modelTypeCustom);
}
await addLLMProvider({
nodeAddress: auth?.node_address ?? '',
token: auth?.api_v2_key ?? '',
agent: {
allowed_message_senders: [],
api_key: data.apikey,
external_url: data.externalUrl,
full_identity_name: `${auth.shinkai_identity}/${auth.profile}/agent/${data.agentName}`,
id: data.agentName,
perform_locally: false,
storage_bucket_permissions: [],
toolkit_permissions: [],
model,
},
enableTest: true,
});
};

return (
<SubpageLayout title={t('llmProviders.add')}>
<SubpageLayout className="max-w-lg" title={t('llmProviders.add')}>
<Form {...addAgentForm}>
<form
className="space-y-10"
Expand Down Expand Up @@ -433,16 +454,36 @@ const AddAIPage = () => {
/>
</div>

{isError && <ErrorMessage message={error.message} />}

<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
type="submit"
>
{t('llmProviders.add')}
</Button>
{isPending ? (
<div className="flex flex-row items-center justify-center space-x-1">
<Loader2 className="h-5 w-5 shrink-0 animate-spin" />
</div>
) : (
<div className="flex flex-col items-center gap-4">
<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
size="sm"
type="submit"
variant="outline"
>
{t('llmProviders.add')}
</Button>
<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
onClick={() => {
addAgentForm.handleSubmit(handleTestAndSave)();
}}
size="sm"
type="button"
>
Test & Add AI
</Button>
</div>
)}
</form>
</Form>
</SubpageLayout>
Expand Down
4 changes: 2 additions & 2 deletions apps/shinkai-desktop/src/pages/ais.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ const AIsPage = () => {

const onAddAgentClick = () => {
if (isLocalShinkaiNodeIsUse) {
navigate('/local-ais');
return;
}
navigate('/add-ai');
navigate('/local-ais');
// navigate('/add-ai');
};

return (
Expand Down
15 changes: 15 additions & 0 deletions libs/shinkai-message-ts/src/api/jobs/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,21 @@ export const addLLMProvider = async (
);
return response.data as AddLLMProviderResponse;
};
export const testLLMProvider = async (
nodeAddress: string,
bearerToken: string,
payload: AddLLMProviderRequest,
) => {
const response = await httpClient.post(
urlJoin(nodeAddress, '/v2/test_llm_provider'),
{ ...payload, model: getModelString(payload.model) },
{
headers: { Authorization: `Bearer ${bearerToken}` },
responseType: 'json',
},
);
return response.data as AddLLMProviderResponse;
};

export const updateLLMProvider = async (
nodeAddress: string,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import { addLLMProvider as addLLMProviderAPI } from '@shinkai_network/shinkai-message-ts/api/jobs/index';
import {
addLLMProvider as addLLMProviderAPI,
testLLMProvider,
} from '@shinkai_network/shinkai-message-ts/api/jobs/index';

import { AddLLMProviderInput } from './types';

export const addLLMProvider = async ({
nodeAddress,
token,
agent,
enableTest,
}: AddLLMProviderInput) => {
if (!agent.model.Ollama && enableTest) {
await testLLMProvider(nodeAddress, token, agent);
}
const data = await addLLMProviderAPI(nodeAddress, token, agent);
return data;
};
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ import {
export type AddLLMProviderInput = Token & {
nodeAddress: string;
agent: SerializedLLMProvider;
enableTest?: boolean;
};
export type AddLLMProviderOutput = AddLLMProviderResponse;
19 changes: 19 additions & 0 deletions libs/shinkai-node-state/src/v2/mutations/testLLMProvider/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import {
addLLMProvider as addLLMProviderAPI,
testLLMProvider,
} from '@shinkai_network/shinkai-message-ts/api/jobs/index';

import { AddLLMProviderInput } from './types';

export const addLLMProvider = async ({
nodeAddress,
token,
agent,
}: AddLLMProviderInput) => {
if (!agent.model.Ollama) {
await testLLMProvider(nodeAddress, token, agent);
}

const data = await addLLMProviderAPI(nodeAddress, token, agent);
return data;
};
11 changes: 11 additions & 0 deletions libs/shinkai-node-state/src/v2/mutations/testLLMProvider/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Token } from '@shinkai_network/shinkai-message-ts/api/general/types';
import {
AddLLMProviderResponse,
SerializedLLMProvider,
} from '@shinkai_network/shinkai-message-ts/api/jobs/types';

export type AddLLMProviderInput = Token & {
nodeAddress: string;
agent: SerializedLLMProvider;
};
export type AddLLMProviderOutput = AddLLMProviderResponse;
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import type { UseMutationOptions } from '@tanstack/react-query';
import { useMutation } from '@tanstack/react-query';

import { APIError } from '../../types';
import { addLLMProvider } from '.';
import { AddLLMProviderInput, AddLLMProviderOutput } from './types';

type Options = UseMutationOptions<
AddLLMProviderOutput,
APIError,
AddLLMProviderInput
>;

export const useAddLLMProvider = (options?: Options) => {
return useMutation({
mutationFn: addLLMProvider,
...options,
});
};

0 comments on commit 3a7fcd3

Please sign in to comment.