Skip to content

Commit

Permalink
Merge pull request #34 from oracle-devrel/backendSupportforDedicatedS…
Browse files Browse the repository at this point in the history
…ervingMode

updating to support dedicated serving mode with fine-tuned models
  • Loading branch information
vmleon authored Sep 23, 2024
2 parents aea7905 + 0343695 commit 7a41b57
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 67 deletions.
23 changes: 19 additions & 4 deletions app/src/components/content/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ type Chat = {
answer?: string;
loading?: string;
};
// Model metadata as served by the backend's /api/genai/models endpoint
// (mirrors the Java GenAiModel DTO — presumably; confirm against backend).
type Model = {
id: string; // model OCID
name: string; // display name
vendor: string; // e.g. "cohere"
version: string;
capabilities: Array<string>; // e.g. "TEXT_GENERATION", "FINE_TUNE"
timeCreated: string;
};

const defaultServiceType: string = localStorage.getItem("service") || "text";
const defaultBackendType: string = localStorage.getItem("backend") || "java";
Expand All @@ -46,6 +54,7 @@ const Content = () => {
const question = useRef<string>();
const chatData = useRef<Array<object>>([]);
const socket = useRef<WebSocket>();
const finetune = useRef<boolean>(false);
const [client, setClient] = useState<Client | null>(null);

const messagesDP = useRef(
Expand Down Expand Up @@ -167,7 +176,13 @@ const Content = () => {
JSON.stringify({ msgType: "question", data: question.current })
);
} else {
sendPrompt(client, question.current!, modelId!, conversationId!);
sendPrompt(
client,
question.current!,
modelId!,
conversationId!,
finetune.current
);
}
}
};
Expand Down Expand Up @@ -199,9 +214,9 @@ const Content = () => {
localStorage.setItem("backend", backend);
location.reload();
};
const modelIdChangeHandler = (event: CustomEvent) => {
console.log("model Id: ", event.detail.value);
if (event.detail.value != null) setModelId(event.detail.value);
// Invoked by Settings when the user picks a model: records whether the
// selection is a fine-tuned model (served via a dedicated endpoint) and
// stores the selected OCID.
const modelIdChangeHandler = (value: string, modelType: boolean) => {
  finetune.current = modelType;
  if (value != null) {
    setModelId(value);
  }
};
const clearSummary = () => {
setSummaryResults("");
Expand Down
72 changes: 67 additions & 5 deletions app/src/components/content/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "oj-c/select-single";
import "ojs/ojlistitemlayout";
import "ojs/ojhighlighttext";
import MutableArrayDataProvider = require("ojs/ojmutablearraydataprovider");
import { ojSelectSingle } from "@oracle/oraclejet/ojselectsingle";

type ServiceTypeVal = "text" | "summary" | "sim";
type BackendTypeVal = "java" | "python";
Expand All @@ -17,7 +18,7 @@ type Props = {
backendType: BackendTypeVal;
aiServiceChange: (service: ServiceTypeVal) => void;
backendChange: (backend: BackendTypeVal) => void;
modelIdChange: (modelName: any) => void;
modelIdChange: (modelId: any, modelData: any) => void;
};

const serviceTypes = [
Expand All @@ -30,6 +31,21 @@ const backendTypes = [
{ value: "java", label: "Java" },
{ value: "python", label: "Python" },
];
// Model metadata returned by /api/genai/models; used to filter the
// model picker and to detect fine-tune capability on selection.
type Model = {
id: string; // model OCID
name: string; // display name
vendor: string; // e.g. "cohere"
version: string;
capabilities: Array<string>; // e.g. "TEXT_GENERATION", "FINE_TUNE"
timeCreated: string;
};
// Dedicated-serving endpoint returned by /api/genai/endpoints
// (mirrors the Java GenAiEndpoint DTO).
type Endpoint = {
id: string; // endpoint OCID — sent as the "model" id when fine-tuned
name: string;
state: string; // lifecycle state
model: string; // OCID of the model this endpoint serves
timeCreated: string;
};
const serviceOptionsDP = new MutableArrayDataProvider<
Services["value"],
Services
Expand All @@ -50,8 +66,11 @@ export const Settings = (props: Props) => {
};

const modelDP = useRef(
new MutableArrayDataProvider<string, {}>([], { keyAttributes: "id" })
new MutableArrayDataProvider<string, {}>([], {
keyAttributes: "id",
})
);
const endpoints = useRef<Array<Endpoint>>();

const fetchModels = async () => {
try {
Expand All @@ -60,9 +79,8 @@ export const Settings = (props: Props) => {
throw new Error(`Response status: ${response.status}`);
}
const json = await response.json();
const result = json.filter((model: any) => {
const result = json.filter((model: Model) => {
if (
// model.capabilities.includes("FINE_TUNE") &&
model.capabilities.includes("TEXT_GENERATION") &&
(model.vendor == "cohere" || model.vendor == "") &&
model.version != "14.2"
Expand All @@ -77,11 +95,55 @@ export const Settings = (props: Props) => {
);
}
};
// Fetches the dedicated-serving endpoints from the Java backend and caches
// them in the `endpoints` ref for later lookup in modelChangeHandler.
// Failures are logged and swallowed deliberately: the Python backend has no
// endpoints route, and the model picker still works without endpoint data.
const fetchEndpoints = async () => {
  try {
    const response = await fetch("/api/genai/endpoints");
    if (!response.ok) {
      throw new Error(`Response status: ${response.status}`);
    }
    const json = await response.json();
    // No endpoint-level filtering rules yet: keep every (non-nullish) entry.
    // Tighten this predicate here when filtering requirements are decided.
    const result = json.filter((endpoint: Endpoint) => endpoint != null);
    endpoints.current = result;
  } catch (error: unknown) {
    // Narrow before touching .message — catch variables are untyped.
    const message = error instanceof Error ? error.message : String(error);
    console.log(
      "Java service not available for fetching list of Endpoints: ",
      message
    );
  }
};

useEffect(() => {
fetchEndpoints();
fetchModels();
}, []);

const modelChangeHandler = async (
event: ojSelectSingle.valueChanged<string, {}>
) => {
let selected = event.detail.value;
let finetune = false;
const asyncIterator = modelDP.current.fetchFirst()[Symbol.asyncIterator]();
let result = await asyncIterator.next();
let value = result.value;
let data = value.data as Array<Model>;
let idx = data.find((e: Model) => {
if (e.id === selected) return e;
});
if (idx?.capabilities.includes("FINE_TUNE")) {
finetune = true;
let endpointId = endpoints.current?.find((e: Endpoint) => {
if (e.model === event.detail.value) {
return e.id;
}
});
selected = endpointId ? endpointId.id : event.detail.value;
}
props.modelIdChange(selected, finetune);
};

const modelTemplate = (item: any) => {
return (
<oj-list-item-layout class="oj-listitemlayout-padding-off">
Expand Down Expand Up @@ -134,7 +196,7 @@ export const Settings = (props: Props) => {
data={modelDP.current}
labelHint={"Model"}
itemText={"name"}
onvalueChanged={props.modelIdChange}
onvalueChanged={modelChangeHandler}
>
<template slot="itemTemplate" render={modelTemplate}></template>
</oj-c-select-single>
Expand Down
4 changes: 3 additions & 1 deletion app/src/components/content/stomp-interface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ export const sendPrompt = (
client: Client | null,
prompt: string,
modelId: string,
convoId: string
convoId: string,
finetune: boolean
) => {
if (client?.connected) {
console.log("Sending prompt: ", prompt);
Expand All @@ -134,6 +135,7 @@ export const sendPrompt = (
conversationId: convoId,
content: prompt,
modelId: modelId,
finetune: finetune,
}),
});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import com.oracle.bmc.generativeai.GenerativeAiClient;
import com.oracle.bmc.generativeai.model.ModelCapability;
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
import com.oracle.bmc.generativeai.requests.ListEndpointsRequest;
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
import com.oracle.bmc.generativeai.model.EndpointSummary;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
import dev.victormartin.oci.genai.backend.backend.service.GenerativeAiClientService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -33,11 +37,25 @@ public List<GenAiModel> getModels() {
GenerativeAiClient client = generativeAiClientService.getClient();
ListModelsResponse response = client.listModels(listModelsRequest);
return response.getModelCollection().getItems().stream().map(m -> {
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue).collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(),m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities,
m.getTimeCreated());
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
.collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities, m.getTimeCreated());
return model;
}).collect(Collectors.toList());
}

/**
 * Lists the Generative AI dedicated endpoints in the configured compartment,
 * mapped onto the lightweight {@link GenAiEndpoint} DTO used by the frontend.
 */
@GetMapping("/api/genai/endpoints")
public List<GenAiEndpoint> getEndpoints() {
    logger.info("getEndpoints()");
    ListEndpointsRequest request = ListEndpointsRequest.builder()
            .compartmentId(COMPARTMENT_ID)
            .build();
    GenerativeAiClient client = generativeAiClientService.getClient();
    ListEndpointsResponse response = client.listEndpoints(request);
    return response.getEndpointCollection().getItems().stream()
            .map(e -> new GenAiEndpoint(e.getId(), e.getDisplayName(),
                    e.getLifecycleState(), e.getModelId(), e.getTimeCreated()))
            .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public PromptController(InteractionRepository interactionRepository, OCIGenAISer
@SendToUser("/queue/answer")
public Answer handlePrompt(Prompt prompt) {
String promptEscaped = HtmlUtils.htmlEscape(prompt.content());
boolean finetune = prompt.finetune();
String activeModel = (prompt.modelId() == null) ? hardcodedChatModelId : prompt.modelId();
logger.info("Prompt " + promptEscaped + " received, on model " + activeModel);

Expand All @@ -59,11 +60,8 @@ public Answer handlePrompt(Prompt prompt) {
if (prompt.content().isEmpty()) {
throw new InvalidPromptRequest();
}
// if (prompt.modelId() == null ||
// !prompt.modelId().startsWith("ocid1.generativeaimodel.")) { throw new
// InvalidPromptRequest(); }
saved.setDatetimeResponse(new Date());
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel);
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
saved.setResponse(responseFromGenAI);
interactionRepository.save(saved);
return new Answer(responseFromGenAI, "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.victormartin.oci.genai.backend.backend.dao;

import java.util.Date;
import com.oracle.bmc.generativeai.model.Endpoint;

// DTO for a Generative AI dedicated-serving endpoint, returned by
// GET /api/genai/endpoints. `model` holds the OCID of the model the endpoint
// serves (mapped from EndpointSummary.getModelId() in the controller).
public record GenAiEndpoint(String id, String name, Endpoint.LifecycleState state, String model, Date timeCreated) {
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package dev.victormartin.oci.genai.backend.backend.dao;

public record Prompt(String content, String conversationId, String modelId) {};
// Inbound WebSocket chat payload. When `finetune` is true, `modelId` carries a
// dedicated endpoint OCID (fine-tuned model served via DedicatedServingMode)
// rather than a base model OCID.
public record Prompt(String content, String conversationId, String modelId, boolean finetune) {
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,56 @@

@Service
public class OCIGenAIService {
@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;
@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;

@Autowired
private GenerativeAiInferenceClientService generativeAiInferenceClientService;
@Autowired
private GenerativeAiInferenceClientService generativeAiInferenceClientService;

public String resolvePrompt(String input, String modelId) {
// Build generate text request, send, and get response
CohereLlmInferenceRequest llmInferenceRequest =
CohereLlmInferenceRequest.builder()
.prompt(input)
.maxTokens(600)
.temperature((double)1)
.frequencyPenalty((double)0)
.topP((double)0.75)
.isStream(false)
.isEcho(false)
.build();
public String resolvePrompt(String input, String modelId, boolean finetune) {
// Build generate text request, send, and get response
CohereLlmInferenceRequest llmInferenceRequest = CohereLlmInferenceRequest.builder()
.prompt(input)
.maxTokens(600)
.temperature((double) 1)
.frequencyPenalty((double) 0)
.topP((double) 0.75)
.isStream(false)
.isEcho(false)
.build();

GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.inferenceRequest(llmInferenceRequest)
.build();
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
.generateTextDetails(generateTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
CohereLlmInferenceResponse response =
(CohereLlmInferenceResponse) generateTextResponse.getGenerateTextResult().getInferenceResponse();
String responseTexts = response.getGeneratedTexts()
.stream()
.map(t -> t.getText())
.collect(Collectors.joining(","));
return responseTexts;
}
GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
.servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
: OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.inferenceRequest(llmInferenceRequest)
.build();
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
.generateTextDetails(generateTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
CohereLlmInferenceResponse response = (CohereLlmInferenceResponse) generateTextResponse
.getGenerateTextResult().getInferenceResponse();
String responseTexts = response.getGeneratedTexts()
.stream()
.map(t -> t.getText())
.collect(Collectors.joining(","));
return responseTexts;
}

public String summaryText(String input, String modelId) {
SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.input(input)
.build();
SummarizeTextRequest request = SummarizeTextRequest.builder()
.summarizeTextDetails(summarizeTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
return summaryText;
}
/**
 * Summarizes {@code input} with the given base model, always using on-demand
 * serving mode (summarization has no fine-tuned/dedicated path here).
 *
 * @param input   text to summarize
 * @param modelId base model OCID
 * @return the summary produced by the service
 */
public String summaryText(String input, String modelId) {
    SummarizeTextDetails details = SummarizeTextDetails.builder()
            .servingMode(OnDemandServingMode.builder().modelId(modelId).build())
            .compartmentId(COMPARTMENT_ID)
            .input(input)
            .build();
    SummarizeTextRequest request = SummarizeTextRequest.builder()
            .summarizeTextDetails(details)
            .build();
    GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
    SummarizeTextResponse response = client.summarizeText(request);
    return response.getSummarizeTextResult().getSummary();
}
}

0 comments on commit 7a41b57

Please sign in to comment.