Vraspar/phi 3 ios update #467
base: main
File: ContentView.swift

@@ -1,27 +1,124 @@
  // Copyright (c) Microsoft Corporation. All rights reserved.
  // Licensed under the MIT License.

  import SwiftUI

+ struct Message: Identifiable {
+     let id = UUID()
+     var text: String
+     let isUser: Bool
+ }
+
  struct ContentView: View {
      @ObservedObject var tokenUpdater = SharedTokenUpdater.shared
+     @State private var userInput: String = ""
+     @State private var messages: [Message] = []  // Store chat messages locally
+     @State private var isGenerating: Bool = false  // Track token generation state
+     @State private var stats: String = ""  // Token generation stats

      var body: some View {
          VStack {
+             // Chat bubbles
              ScrollView {
-                 VStack(alignment: .leading) {
-                     ForEach(tokenUpdater.decodedTokens, id: \.self) { token in
-                         Text(token)
-                             .padding(.horizontal, 5)
+                 VStack(alignment: .leading, spacing: 20) {
+                     ForEach(messages) { message in
+                         ChatBubble(text: message.text, isUser: message.isUser)
+                             .padding(.horizontal, 20)
                      }
+                     if !stats.isEmpty {
+                         Text(stats)
+                             .font(.footnote)
+                             .foregroundColor(.gray)
+                             .padding(.horizontal, 20)
+                             .padding(.top, 5)
+                             .multilineTextAlignment(.center)
+                     }
                  }
                  .padding()
+                 .padding(.top, 20)
              }
-             Button("Generate Tokens") {
-                 DispatchQueue.global(qos: .background).async {
-                     // TODO: add user prompt question UI
-                     GenAIGenerator.generate("Who is the current US president?");
-                 }
-             }
+
+             // User input
+             HStack {
+                 TextField("Type your message...", text: $userInput)
+                     .padding()
+                     .background(Color(.systemGray6))
+                     .cornerRadius(20)
+                     .padding(.horizontal)
+
+                 Button(action: {
+                     // Check for non-empty input
+                     guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return }
+
+                     messages.append(Message(text: userInput, isUser: true))
+                     messages.append(Message(text: "", isUser: false))  // Placeholder for AI response
+
+                     // Clear previously generated tokens
+                     SharedTokenUpdater.shared.clearTokens()
+
+                     let prompt = userInput
+                     userInput = ""
+                     isGenerating = true
+
+                     DispatchQueue.global(qos: .background).async {
+                         GenAIGenerator.generate(prompt)
+                     }
+                 }) {
+                     Image(systemName: "paperplane.fill")
+                         .foregroundColor(.white)
+                         .padding()
+                         .background(isGenerating ? Color.gray : Color.pastelGreen)
+                         .clipShape(Circle())
+                         .padding(.trailing, 10)
+                 }
+                 .disabled(isGenerating)
+             }
+             .padding(.bottom, 20)
          }
+         .background(Color(.systemGroupedBackground))
+         .edgesIgnoringSafeArea(.bottom)
+         .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in
+             isGenerating = false  // Re-enable the button when token generation is complete
+         }
+         .onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in
+             // Update the model response bubble as tokens stream in
+             if let lastIndex = messages.lastIndex(where: { !$0.isUser }) {
+                 let combinedText = tokens.joined(separator: "")
+                 messages[lastIndex].text = combinedText
+             }
+         }
+         .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in
+             if let userInfo = notification.userInfo,
+                let totalTime = userInfo["totalTime"] as? Int,
+                let firstTokenTime = userInfo["firstTokenTime"] as? Int,
+                let tokenCount = userInfo["tokenCount"] as? Int {
+                 stats = "Generated \(tokenCount) tokens in \(totalTime) ms. First token in \(firstTokenTime) ms."
+             }
Review comment: could we also include something like this: the token generation and prompt processing rates may be different, and it might be useful to get a sense of both.

Reply: Would be nice to have prompt tokens/s as well, given that it and generation tokens/s are the usual metrics that get compared. That would require something to return the number of tokens in the prompt (possibly the length of the input sequence post-tokenization), as IIRC it's not 1:1 with the number of words.
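As a rough sketch, both rates could be derived on the GenAIGenerator.mm side from values this PR already tracks, assuming OgaSequences::SequenceCount(0) reports the post-tokenization length of the encoded prompt (the C++ wrapper in ort_genai.h exposes a method of this shape):

    // Sketch only: prompt tokens/s vs. generation tokens/s, reusing startTime,
    // firstTokenTime, endTime, tokenCount and `sequences` from the generate: method.
    size_t promptTokenCount = sequences->SequenceCount(0);  // token count, not word count

    auto promptMs = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
    auto genMs = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - firstTokenTime).count();

    // Prompt processing rate: prompt tokens consumed before the first output token.
    double promptTps = promptMs > 0 ? 1000.0 * promptTokenCount / promptMs : 0.0;
    // Generation rate: output tokens emitted after the first one.
    double genTps = genMs > 0 ? 1000.0 * (tokenCount - 1) / genMs : 0.0;

    NSLog(@"Prompt processing: %.1f tokens/s, generation: %.1f tokens/s", promptTps, genTps);

Both numbers could also be added to the TokenGenerationStats userInfo dictionary so the UI can show them alongside the existing stats.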
+             }
+         }
      }
  }

+ struct ChatBubble: View {
+     var text: String
+     var isUser: Bool
+
+     var body: some View {
+         HStack {
+             if isUser {
+                 Spacer()
+                 Text(text)
+                     .padding()
+                     .background(Color.pastelGreen)
+                     .foregroundColor(.white)
+                     .cornerRadius(25)
+                     .padding(.horizontal, 10)
+             } else {
+                 Text(text)
+                     .padding()
+                     .background(Color(.systemGray5))
+                     .foregroundColor(.black)
+                     .cornerRadius(25)
+                     .padding(.horizontal, 20)
+                 Spacer()
+             }
+         }
+     }
+ }

@@ -32,3 +129,8 @@ struct ContentView_Previews: PreviewProvider {
          ContentView()
      }
  }
+
+ // Extension for a pastel green color
+ extension Color {
+     static let pastelGreen = Color(red: 0.6, green: 0.9, blue: 0.6)
+ }
File: GenAIGenerator.mm

@@ -5,45 +5,109 @@
#include "LocalLLM-Swift.h" | ||
#include "ort_genai.h" | ||
#include "ort_genai_c.h" | ||
|
||
#include <chrono> | ||
|
||
@implementation GenAIGenerator | ||
|
||
+ (void)generate:(nonnull NSString*)input_user_question { | ||
NSString* llmPath = [[NSBundle mainBundle] resourcePath]; | ||
const char* modelPath = llmPath.cString; | ||
|
||
auto model = OgaModel::Create(modelPath); | ||
auto tokenizer = OgaTokenizer::Create(*model); | ||
|
||
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question]; | ||
const char* prompt = [promptString UTF8String]; | ||
|
||
auto sequences = OgaSequences::Create(); | ||
tokenizer->Encode(prompt, *sequences); | ||
typedef std::chrono::high_resolution_clock Clock; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: consider steady_clock https://en.cppreference.com/w/cpp/chrono/steady_clock |
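A minimal sketch of the suggested swap; steady_clock is monotonic, so elapsed-time measurements cannot be skewed by system clock adjustments (high_resolution_clock is allowed to alias system_clock, which can be):

    // Monotonic clock for interval measurements.
    typedef std::chrono::steady_clock Clock;
    typedef std::chrono::time_point<Clock> TimePoint;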
+ typedef std::chrono::time_point<Clock> TimePoint;

-   auto params = OgaGeneratorParams::Create(*model);
-   params->SetSearchOption("max_length", 200);
-   params->SetInputSequences(*sequences);
-
-   // Streaming Output to generate token by token
-   auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
-
-   auto generator = OgaGenerator::Create(*model, *params);
+ + (void)generate:(nonnull NSString*)input_user_question {
+   NSLog(@"Starting token generation...");
+
+   NSString* llmPath = [[NSBundle mainBundle] resourcePath];
+   const char* modelPath = llmPath.cString;
+
+   // Log model creation
+   NSLog(@"Creating model ...");
+   auto model = OgaModel::Create(modelPath);

Review comment: performance-wise, it's probably nicer to not re-create the model every time.

Reply: it should definitely be nicer, as the first request is usually the slowest (and typically when doing perf testing we do a warmup query first and exclude that from the timing data).
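One possible shape for that, sketched with dispatch_once so the model is loaded on the first call and reused afterwards (the s_model/s_once names are illustrative, not part of the PR):

    // Sketch: create the model once and reuse it across generate: calls.
    static std::unique_ptr<OgaModel> s_model;
    static dispatch_once_t s_once;
    dispatch_once(&s_once, ^{
      s_model = OgaModel::Create(modelPath);  // only the first request pays the load cost
    });

A warmup query at app startup could then hide even that first-request cost from the user (and be excluded from any timing numbers).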
+   if (!model) {

Review comment: the C++ API throws exceptions. might be simplest to just put most of this method into a try/catch block.
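A sketch of that suggestion: since the C++ wrappers signal failure by throwing (so the if (!model) style checks may never fire), the body can be wrapped so errors surface as a log message rather than an uncaught exception:

    try {
      auto model = OgaModel::Create(modelPath);
      auto tokenizer = OgaTokenizer::Create(*model);
      // ... rest of the generation code ...
    } catch (const std::exception& e) {
      NSLog(@"Token generation failed: %s", e.what());
      return;
    }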
+     NSLog(@"Failed to create model.");
+     return;
+   }
+
+   NSLog(@"Creating tokenizer...");
+   auto tokenizer = OgaTokenizer::Create(*model);
+   if (!tokenizer) {
+     NSLog(@"Failed to create tokenizer.");
+     return;
+   }

Review comment (on lines +29 to +34): The tokenizer could also be created once and re-used, as I believe it's tied to the model, not the prompt, so it can be re-used.
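Extending the one-time initialization sketch above, the tokenizer could be cached next to the model, since both depend only on the model directory (again, names are illustrative):

    // Sketch: cache the tokenizer alongside the model in the same dispatch_once.
    static std::unique_ptr<OgaTokenizer> s_tokenizer;
    dispatch_once(&s_once, ^{
      s_model = OgaModel::Create(modelPath);
      s_tokenizer = OgaTokenizer::Create(*s_model);
    });

The OgaTokenizerStream would still be created per request, since it holds per-generation decode state.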
+   auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
+
+   // Construct the prompt
+   NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
+   const char* prompt = [promptString UTF8String];
+
+   NSLog(@"Encoding prompt...");
+   auto sequences = OgaSequences::Create();
+   tokenizer->Encode(prompt, *sequences);
+
+   // Log parameters
+   NSLog(@"Setting generator parameters...");
+   auto params = OgaGeneratorParams::Create(*model);
+   params->SetSearchOption("max_length", 200);
+   params->SetInputSequences(*sequences);
+
+   NSLog(@"Creating generator...");
+   auto generator = OgaGenerator::Create(*model, *params);
+
+   bool isFirstToken = true;
+   TimePoint startTime = Clock::now();
+   TimePoint firstTokenTime;
+   int tokenCount = 0;
+
+   NSLog(@"Starting token generation loop...");
+   while (!generator->IsDone()) {
+     generator->ComputeLogits();
+     generator->GenerateNextToken();
+
+     if (isFirstToken) {
+       NSLog(@"First token generated.");
+       firstTokenTime = Clock::now();

Review comment (on lines +66 to +67): nit: logging takes time, so we want to record firstTokenTime prior to logging. It might also be interesting to include per-token timing in the log, as that would help build a picture of performance throughout the generation phase: is the time per token consistent? If not, is the variability random, or does it gradually increase or decrease? That can provide hints as to potential causes of performance issues (if there are any). If we do that, we might want to accumulate per-token times in a list and log them at the end, as inserting log calls for every token inside the loop could significantly affect the overall time taken (especially since the log call is synchronous, which NSLog apparently is).
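A sketch of the accumulate-then-log idea (needs #include <vector>; the placement comments show where each piece would slot into the loop):

    // Before the loop: storage for per-token intervals.
    std::vector<long long> perTokenMs;
    TimePoint prevTokenTime = Clock::now();

    // Inside the loop, immediately after GenerateNextToken():
    TimePoint now = Clock::now();
    perTokenMs.push_back(std::chrono::duration_cast<std::chrono::milliseconds>(now - prevTokenTime).count());
    prevTokenTime = now;

    // After the loop: log once, outside the timed region.
    for (size_t i = 0; i < perTokenMs.size(); ++i) {
      NSLog(@"token %zu: %lld ms", i, perTokenMs[i]);
    }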
+       isFirstToken = false;
+     }
+
+     // Get the sequence data
+     const int32_t* seq = generator->GetSequenceData(0);
+     size_t seq_len = generator->GetSequenceCount(0);
+
+     // Decode the new token
+     const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
+
+     // Check for decoding failure
+     if (!decode_tokens) {
+       NSLog(@"Token decoding failed.");
+       break;
+     }
+
+     NSLog(@"Decoded token: %s", decode_tokens);

Review comment: nit: might want a setting to control whether we log inside the loop (if we're after the best perf numbers possible, we probably don't want to), or a way to exclude time spent in NSLog calls from the total.
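One way to make that switchable at build time (the flag name here is made up for illustration):

    // Sketch: compile-time gate for per-token logging.
    #ifndef LOCALLLM_VERBOSE_TOKEN_LOGGING
    #define LOCALLLM_VERBOSE_TOKEN_LOGGING 0
    #endif

    #if LOCALLLM_VERBOSE_TOKEN_LOGGING
        NSLog(@"Decoded token: %s", decode_tokens);
    #endif

Release performance runs would build with the flag at 0 so no NSLog sits inside the generation loop.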
+     tokenCount++;
+
+     // Convert token to NSString and update UI on the main thread
+     NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
+     [SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
+   }
+
-   while (!generator->IsDone()) {
-     generator->ComputeLogits();
-     generator->GenerateNextToken();
-
-     const int32_t* seq = generator->GetSequenceData(0);
-     size_t seq_len = generator->GetSequenceCount(0);
-     const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
-
-     NSLog(@"Decoded tokens: %s", decode_tokens);
-
-     // Add decoded token to SharedTokenUpdater
-     NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
-     [SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
-   }
+   TimePoint endTime = Clock::now();
+   auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
+   auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
+
+   NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %d", totalDuration, firstTokenDuration, tokenCount);
+
+   NSDictionary* stats = @{
+     @"totalTime" : @(totalDuration),
+     @"firstTokenTime" : @(firstTokenDuration),
+     @"tokenCount" : @(tokenCount)
+   };
+
    // notify main thread that token generation is complete
    dispatch_async(dispatch_get_main_queue(), ^{
      [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
+     [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
    });
-   NSLog(@"Token generation completed.");
  }

  @end
Review comment: ❤️
Review comment: should we keep the copyright notice?