Skip to content

Commit

Permalink
feat: revamp local/cloud models (#577)
Browse files Browse the repository at this point in the history
* feat: revamp onboarding local/cloud models

* feat: revamp onboarding local/cloud models

* feat: test llm
  • Loading branch information
paulclindo authored Dec 20, 2024
1 parent 5f1d679 commit a029ae3
Show file tree
Hide file tree
Showing 12 changed files with 718 additions and 127 deletions.
94 changes: 71 additions & 23 deletions apps/shinkai-desktop/src/pages/add-ai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import {
import { useAddLLMProvider } from '@shinkai_network/shinkai-node-state/v2/mutations/addLLMProvider/useAddLLMProvider';
import {
Button,
ErrorMessage,
Form,
FormControl,
FormField,
Expand All @@ -30,11 +29,13 @@ import {
TextField,
} from '@shinkai_network/shinkai-ui';
import { cn } from '@shinkai_network/shinkai-ui/utils';
import { Loader2 } from 'lucide-react';
import { useEffect, useState } from 'react';
import { useForm } from 'react-hook-form';
import { useNavigate } from 'react-router-dom';
import { toast } from 'sonner';

import { useURLQueryParams } from '../hooks/use-url-query-params';
import { useAuth } from '../store/auth';
import { SubpageLayout } from './layout/simple-layout';

Expand Down Expand Up @@ -99,16 +100,21 @@ const AddAIPage = () => {
const { t } = useTranslation();
const auth = useAuth((state) => state.auth);
const navigate = useNavigate();
const query = useURLQueryParams();

const addAgentForm = useForm<AddAgentFormSchema>({
resolver: zodResolver(addAgentSchema),
defaultValues: addAgentFormDefault,
});
const {
mutateAsync: addLLMProvider,
isPending,
isError,
error,
} = useAddLLMProvider({

const preSelectedAiProvider = query.get('aiProvider') as Models;
useEffect(() => {
if (preSelectedAiProvider) {
addAgentForm.setValue('model', preSelectedAiProvider);
}
}, [addAgentForm, preSelectedAiProvider]);

const { mutateAsync: addLLMProvider, isPending } = useAddLLMProvider({
onSuccess: (_, variables) => {
navigate('/inboxes', {
state: {
Expand All @@ -122,7 +128,6 @@ const AddAIPage = () => {
});
},
});

const {
model: currentModel,
isCustomModel: isCustomModelMode,
Expand Down Expand Up @@ -237,11 +242,37 @@ const AddAIPage = () => {
toolkit_permissions: [],
model,
},
enableTest: false,
});
};
const handleTestAndSave = async (data: AddAgentFormSchema) => {
if (!auth) return;
let model = getModelObject(data.model, data.modelType);
if (isCustomModelMode && data.modelCustom && data.modelTypeCustom) {
model = getModelObject(data.modelCustom, data.modelTypeCustom);
} else if (isCustomModelType && data.modelTypeCustom) {
model = getModelObject(data.model, data.modelTypeCustom);
}
await addLLMProvider({
nodeAddress: auth?.node_address ?? '',
token: auth?.api_v2_key ?? '',
agent: {
allowed_message_senders: [],
api_key: data.apikey,
external_url: data.externalUrl,
full_identity_name: `${auth.shinkai_identity}/${auth.profile}/agent/${data.agentName}`,
id: data.agentName,
perform_locally: false,
storage_bucket_permissions: [],
toolkit_permissions: [],
model,
},
enableTest: true,
});
};

return (
<SubpageLayout title={t('llmProviders.add')}>
<SubpageLayout className="max-w-lg" title={t('llmProviders.add')}>
<Form {...addAgentForm}>
<form
className="space-y-10"
Expand Down Expand Up @@ -282,10 +313,7 @@ const AddAIPage = () => {
<FormLabel>
{t('llmProviders.form.modelProvider')}
</FormLabel>
<Select
defaultValue={field.value}
onValueChange={field.onChange}
>
<Select onValueChange={field.onChange} value={field.value}>
<FormControl>
<SelectTrigger>
<SelectValue placeholder={' '} />
Expand Down Expand Up @@ -426,16 +454,36 @@ const AddAIPage = () => {
/>
</div>

{isError && <ErrorMessage message={error.message} />}

<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
type="submit"
>
{t('llmProviders.add')}
</Button>
{isPending ? (
<div className="flex flex-row items-center justify-center space-x-1">
<Loader2 className="h-5 w-5 shrink-0 animate-spin" />
</div>
) : (
<div className="flex flex-col items-center gap-4">
<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
size="sm"
type="submit"
variant="outline"
>
{t('llmProviders.add')}
</Button>
<Button
className="w-full"
disabled={isPending}
isLoading={isPending}
onClick={() => {
addAgentForm.handleSubmit(handleTestAndSave)();
}}
size="sm"
type="button"
>
Test & Add AI
</Button>
</div>
)}
</form>
</Form>
</SubpageLayout>
Expand Down
Loading

0 comments on commit a029ae3

Please sign in to comment.