diff --git a/app/src/components/content/index.tsx b/app/src/components/content/index.tsx index 84df4f33..d1606401 100644 --- a/app/src/components/content/index.tsx +++ b/app/src/components/content/index.tsx @@ -25,6 +25,14 @@ type Chat = { answer?: string; loading?: string; }; +type Model = { + id: string; + name: string; + vendor: string; + version: string; + capabilities: Array; + timeCreated: string; +}; const defaultServiceType: string = localStorage.getItem("service") || "text"; const defaultBackendType: string = localStorage.getItem("backend") || "java"; @@ -46,6 +54,7 @@ const Content = () => { const question = useRef(); const chatData = useRef>([]); const socket = useRef(); + const finetune = useRef(false); const [client, setClient] = useState(null); const messagesDP = useRef( @@ -167,7 +176,13 @@ const Content = () => { JSON.stringify({ msgType: "question", data: question.current }) ); } else { - sendPrompt(client, question.current!, modelId!, conversationId!); + sendPrompt( + client, + question.current!, + modelId!, + conversationId!, + finetune.current + ); } } }; @@ -199,9 +214,9 @@ const Content = () => { localStorage.setItem("backend", backend); location.reload(); }; - const modelIdChangeHandler = (event: CustomEvent) => { - console.log("model Id: ", event.detail.value); - if (event.detail.value != null) setModelId(event.detail.value); + const modelIdChangeHandler = (value: string, modelType: boolean) => { + if (value != null) setModelId(value); + finetune.current = modelType; }; const clearSummary = () => { setSummaryResults(""); diff --git a/app/src/components/content/settings.tsx b/app/src/components/content/settings.tsx index aeeee079..269f66d3 100644 --- a/app/src/components/content/settings.tsx +++ b/app/src/components/content/settings.tsx @@ -5,6 +5,7 @@ import "oj-c/select-single"; import "ojs/ojlistitemlayout"; import "ojs/ojhighlighttext"; import MutableArrayDataProvider = require("ojs/ojmutablearraydataprovider"); +import { ojSelectSingle } from "@oracle/oraclejet/ojselectsingle"; type ServiceTypeVal = "text" | "summary" | "sim"; type BackendTypeVal = "java" | "python"; @@ -17,7 +18,7 @@ type Props = { backendType: BackendTypeVal; aiServiceChange: (service: ServiceTypeVal) => void; backendChange: (backend: BackendTypeVal) => void; - modelIdChange: (modelName: any) => void; + modelIdChange: (modelId: any, modelData: any) => void; }; const serviceTypes = [ @@ -30,6 +31,21 @@ const backendTypes = [ { value: "java", label: "Java" }, { value: "python", label: "Python" }, ]; +type Model = { + id: string; + name: string; + vendor: string; + version: string; + capabilities: Array; + timeCreated: string; +}; +type Endpoint = { + id: string; + name: string; + state: string; + model: string; + timeCreated: string; +}; const serviceOptionsDP = new MutableArrayDataProvider< Services["value"], Services @@ -50,8 +66,11 @@ export const Settings = (props: Props) => { }; const modelDP = useRef( - new MutableArrayDataProvider([], { keyAttributes: "id" }) + new MutableArrayDataProvider([], { + keyAttributes: "id", + }) ); + const endpoints = useRef>(); const fetchModels = async () => { try { @@ -60,9 +79,8 @@ export const Settings = (props: Props) => { throw new Error(`Response status: ${response.status}`); } const json = await response.json(); - const result = json.filter((model: any) => { + const result = json.filter((model: Model) => { if ( - // model.capabilities.includes("FINE_TUNE") && model.capabilities.includes("TEXT_GENERATION") && (model.vendor == "cohere" || model.vendor == "") && model.version != "14.2" @@ -77,11 +95,55 @@ export const Settings = (props: Props) => { ); } }; + const fetchEndpoints = async () => { + try { + const response = await fetch("/api/genai/endpoints"); + if (!response.ok) { + throw new Error(`Response status: ${response.status}`); + } + const json = await response.json(); + const result = json.filter((endpoint: Endpoint) => { + // add filtering code here + return endpoint; + }); + endpoints.current = result; + } catch (error: any) { + console.log( + "Java service not available for fetching list of Endpoints: ", + error.message + ); + } + }; useEffect(() => { + fetchEndpoints(); fetchModels(); }, []); + const modelChangeHandler = async ( + event: ojSelectSingle.valueChanged + ) => { + let selected = event.detail.value; + let finetune = false; + const asyncIterator = modelDP.current.fetchFirst()[Symbol.asyncIterator](); + let result = await asyncIterator.next(); + let value = result.value; + let data = value.data as Array; + let idx = data.find((e: Model) => { + if (e.id === selected) return e; + }); + if (idx?.capabilities.includes("FINE_TUNE")) { + finetune = true; + let endpointId = endpoints.current?.find((e: Endpoint) => { + if (e.model === event.detail.value) { + return e.id; + } + }); + selected = endpointId ? endpointId.id : event.detail.value; + } + props.modelIdChange(selected, finetune); + }; + const modelTemplate = (item: any) => { return ( @@ -134,7 +196,7 @@ export const Settings = (props: Props) => { data={modelDP.current} labelHint={"Model"} itemText={"name"} - onvalueChanged={props.modelIdChange} + onvalueChanged={modelChangeHandler} > diff --git a/app/src/components/content/stomp-interface.tsx b/app/src/components/content/stomp-interface.tsx index 9216c609..ca4d3498 100644 --- a/app/src/components/content/stomp-interface.tsx +++ b/app/src/components/content/stomp-interface.tsx @@ -124,7 +124,8 @@ export const sendPrompt = ( client: Client | null, prompt: string, modelId: string, - convoId: string + convoId: string, + finetune: boolean ) => { if (client?.connected) { console.log("Sending prompt: ", prompt); @@ -134,6 +135,7 @@ export const sendPrompt = ( conversationId: convoId, content: prompt, modelId: modelId, + finetune: finetune, }), }); } else { diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java index 9564e33c..650ddcc9 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java @@ -3,8 +3,12 @@ import com.oracle.bmc.generativeai.GenerativeAiClient; import com.oracle.bmc.generativeai.model.ModelCapability; import com.oracle.bmc.generativeai.requests.ListModelsRequest; +import com.oracle.bmc.generativeai.requests.ListEndpointsRequest; import com.oracle.bmc.generativeai.responses.ListModelsResponse; +import com.oracle.bmc.generativeai.responses.ListEndpointsResponse; +import com.oracle.bmc.generativeai.model.EndpointSummary; import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel; +import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint; import dev.victormartin.oci.genai.backend.backend.service.GenerativeAiClientService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,11 +37,25 @@ public List getModels() { GenerativeAiClient client = generativeAiClientService.getClient(); ListModelsResponse response = client.listModels(listModelsRequest); return response.getModelCollection().getItems().stream().map(m -> { - List capabilities = m.getCapabilities().stream().map(ModelCapability::getValue).collect(Collectors.toList()); - GenAiModel model = new GenAiModel(m.getId(),m.getDisplayName(), m.getVendor(), m.getVersion(), - capabilities, - m.getTimeCreated()); + List capabilities = m.getCapabilities().stream().map(ModelCapability::getValue) + .collect(Collectors.toList()); + GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(), + capabilities, m.getTimeCreated()); return model; }).collect(Collectors.toList()); } + + @GetMapping("/api/genai/endpoints") + public List getEndpoints() { + logger.info("getEndpoints()"); + ListEndpointsRequest listEndpointsRequest = ListEndpointsRequest.builder().compartmentId(COMPARTMENT_ID) + .build(); + GenerativeAiClient client = generativeAiClientService.getClient(); + ListEndpointsResponse response = client.listEndpoints(listEndpointsRequest); + return response.getEndpointCollection().getItems().stream().map(e -> { + GenAiEndpoint endpoint = new GenAiEndpoint(e.getId(), e.getDisplayName(), e.getLifecycleState(), + e.getModelId(), e.getTimeCreated()); + return endpoint; + }).collect(Collectors.toList()); + } } diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java index 7fc05eba..c7005f74 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java @@ -45,6 +45,7 @@ public PromptController(InteractionRepository interactionRepository, OCIGenAISer @SendToUser("/queue/answer") public Answer handlePrompt(Prompt prompt) { String promptEscaped = HtmlUtils.htmlEscape(prompt.content()); + boolean finetune = prompt.finetune(); String activeModel = (prompt.modelId() == null) ? hardcodedChatModelId : prompt.modelId(); logger.info("Prompt " + promptEscaped + " received, on model " + activeModel); @@ -59,11 +60,8 @@ public Answer handlePrompt(Prompt prompt) { if (prompt.content().isEmpty()) { throw new InvalidPromptRequest(); } - // if (prompt.modelId() == null || - // !prompt.modelId().startsWith("ocid1.generativeaimodel.")) { throw new - // InvalidPromptRequest(); } saved.setDatetimeResponse(new Date()); - String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel); + String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune); saved.setResponse(responseFromGenAI); interactionRepository.save(saved); return new Answer(responseFromGenAI, ""); diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/GenAiEndpoint.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/GenAiEndpoint.java new file mode 100644 index 00000000..85d1d53c --- /dev/null +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/GenAiEndpoint.java @@ -0,0 +1,7 @@ +package dev.victormartin.oci.genai.backend.backend.dao; + +import java.util.Date; +import com.oracle.bmc.generativeai.model.Endpoint; + +public record GenAiEndpoint(String id, String name, Endpoint.LifecycleState state, String model, Date timeCreated) { +} diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/Prompt.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/Prompt.java index 54a23559..42cf7071 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/Prompt.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/dao/Prompt.java @@ -1,3 +1,4 @@ package dev.victormartin.oci.genai.backend.backend.dao; -public record Prompt(String content, String conversationId, String modelId) {}; \ No newline at end of file +public record Prompt(String content, String conversationId, String modelId, boolean finetune) { +}; \ No newline at end of file diff --git a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java index 6da6f0aa..23610453 100644 --- a/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java +++ b/backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java @@ -14,56 +14,56 @@ @Service public class OCIGenAIService { - @Value("${genai.compartment_id}") - private String COMPARTMENT_ID; + @Value("${genai.compartment_id}") + private String COMPARTMENT_ID; - @Autowired - private GenerativeAiInferenceClientService generativeAiInferenceClientService; + @Autowired + private GenerativeAiInferenceClientService generativeAiInferenceClientService; - public String resolvePrompt(String input, String modelId) { - // Build generate text request, send, and get response - CohereLlmInferenceRequest llmInferenceRequest = - CohereLlmInferenceRequest.builder() - .prompt(input) - .maxTokens(600) - .temperature((double)1) - .frequencyPenalty((double)0) - .topP((double)0.75) - .isStream(false) - .isEcho(false) - .build(); + public String resolvePrompt(String input, String modelId, boolean finetune) { + // Build generate text request, send, and get response + CohereLlmInferenceRequest llmInferenceRequest = CohereLlmInferenceRequest.builder() + .prompt(input) + .maxTokens(600) + .temperature((double) 1) + .frequencyPenalty((double) 0) + .topP((double) 0.75) + .isStream(false) + .isEcho(false) + .build(); - GenerateTextDetails generateTextDetails = GenerateTextDetails.builder() - .servingMode(OnDemandServingMode.builder().modelId(modelId).build()) - .compartmentId(COMPARTMENT_ID) - .inferenceRequest(llmInferenceRequest) - .build(); - GenerateTextRequest generateTextRequest = GenerateTextRequest.builder() - .generateTextDetails(generateTextDetails) - .build(); - GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient(); - GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest); - CohereLlmInferenceResponse response = - (CohereLlmInferenceResponse) generateTextResponse.getGenerateTextResult().getInferenceResponse(); - String responseTexts = response.getGeneratedTexts() - .stream() - .map(t -> t.getText()) - .collect(Collectors.joining(",")); - return responseTexts; - } + GenerateTextDetails generateTextDetails = GenerateTextDetails.builder() + .servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build() + : OnDemandServingMode.builder().modelId(modelId).build()) + .compartmentId(COMPARTMENT_ID) + .inferenceRequest(llmInferenceRequest) + .build(); + GenerateTextRequest generateTextRequest = GenerateTextRequest.builder() + .generateTextDetails(generateTextDetails) + .build(); + GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient(); + GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest); + CohereLlmInferenceResponse response = (CohereLlmInferenceResponse) generateTextResponse + .getGenerateTextResult().getInferenceResponse(); + String responseTexts = response.getGeneratedTexts() + .stream() + .map(t -> t.getText()) + .collect(Collectors.joining(",")); + return responseTexts; + } - public String summaryText(String input, String modelId) { - SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder() - .servingMode(OnDemandServingMode.builder().modelId(modelId).build()) - .compartmentId(COMPARTMENT_ID) - .input(input) - .build(); - SummarizeTextRequest request = SummarizeTextRequest.builder() - .summarizeTextDetails(summarizeTextDetails) - .build(); - GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient(); - SummarizeTextResponse summarizeTextResponse = client.summarizeText(request); - String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary(); - return summaryText; - } + public String summaryText(String input, String modelId) { + SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder() + .servingMode(OnDemandServingMode.builder().modelId(modelId).build()) + .compartmentId(COMPARTMENT_ID) + .input(input) + .build(); + SummarizeTextRequest request = SummarizeTextRequest.builder() + .summarizeTextDetails(summarizeTextDetails) + .build(); + GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient(); + SummarizeTextResponse summarizeTextResponse = client.summarizeText(request); + String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary(); + return summaryText; + } }