Commit
Refactor chat UI and token generation logic
vraspar committed Oct 3, 2024
1 parent 05f7a72 commit 1b1d999
Showing 5 changed files with 220 additions and 46 deletions.
128 changes: 115 additions & 13 deletions mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift
@@ -1,27 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import SwiftUI


struct Message: Identifiable {
let id = UUID()
var text: String
let isUser: Bool
}

struct ContentView: View {
@ObservedObject var tokenUpdater = SharedTokenUpdater.shared
@State private var userInput: String = ""
@State private var messages: [Message] = [] // Store chat messages locally
@State private var isGenerating: Bool = false // Track token generation state
@State private var stats: String = "" // token generation stats

var body: some View {
VStack {
// ChatBubbles
ScrollView {
VStack(alignment: .leading) {
ForEach(tokenUpdater.decodedTokens, id: \.self) { token in
Text(token)
.padding(.horizontal, 5)
VStack(alignment: .leading, spacing: 20) {
ForEach(messages) { message in
ChatBubble(text: message.text, isUser: message.isUser)
.padding(.horizontal, 20)
}
if !stats.isEmpty {
Text(stats)
.font(.footnote)
.foregroundColor(.gray)
.padding(.horizontal, 20)
.padding(.top, 5)
.multilineTextAlignment(.center)
}
}
.padding()
.padding(.top, 20)
}
Button("Generate Tokens") {
DispatchQueue.global(qos: .background).async {
// TODO: add user prompt question UI
GenAIGenerator.generate("Who is the current US president?");


// User input
HStack {
TextField("Type your message...", text: $userInput)
.padding()
.background(Color(.systemGray6))
.cornerRadius(20)
.padding(.horizontal)

Button(action: {
// Check for non-empty input
guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return }

messages.append(Message(text: userInput, isUser: true))
messages.append(Message(text: "", isUser: false)) // Placeholder for AI response


// clear previously generated tokens
SharedTokenUpdater.shared.clearTokens()

let prompt = userInput
userInput = ""
isGenerating = true


DispatchQueue.global(qos: .background).async {
GenAIGenerator.generate(prompt)
}
}) {
Image(systemName: "paperplane.fill")
.foregroundColor(.white)
.padding()
.background(isGenerating ? Color.gray : Color.pastelGreen)
.clipShape(Circle())
.padding(.trailing, 10)
}
.disabled(isGenerating)
}
.padding(.bottom, 20)
}
.background(Color(.systemGroupedBackground))
.edgesIgnoringSafeArea(.bottom)
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in
isGenerating = false // Re-enable the button when token generation is complete
}
.onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in
// update model response
if let lastIndex = messages.lastIndex(where: { !$0.isUser }) {
let combinedText = tokens.joined(separator: "")
messages[lastIndex].text = combinedText
}
}
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in
if let userInfo = notification.userInfo,
let totalTime = userInfo["totalTime"] as? Int,
let firstTokenTime = userInfo["firstTokenTime"] as? Int,
let tokenCount = userInfo["tokenCount"] as? Int {
stats = "Generated \(tokenCount) tokens in \(totalTime) ms. First token in \(firstTokenTime) ms."
}
}
}
}

struct ChatBubble: View {
var text: String
var isUser: Bool

var body: some View {
HStack {
if isUser {
Spacer()
Text(text)
.padding()
.background(Color.pastelGreen)
.foregroundColor(.white)
.cornerRadius(25)
.padding(.horizontal, 10)
} else {
Text(text)
.padding()
.background(Color(.systemGray5))
.foregroundColor(.black)
.cornerRadius(25)
.padding(.horizontal, 20)
Spacer()
}
}
}
@@ -32,3 +129,8 @@ struct ContentView_Previews: PreviewProvider {
ContentView()
}
}

// Extension for a pastel green color
extension Color {
static let pastelGreen = Color(red: 0.6, green: 0.9, blue: 0.6)
}
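
Note: the view above keys its `onReceive` handlers off raw notification-name strings ("TokenGenerationCompleted", "TokenGenerationStats") that must match the strings posted from GenAIGenerator.mm below. A minimal hardening sketch, not part of this commit: centralize the names in a `Notification.Name` extension so a mismatch becomes a compile-time error on the Swift side.

```swift
import Foundation

// Hypothetical helper (not in this commit): one source of truth for the
// notification names shared between ContentView.swift and GenAIGenerator.mm.
extension Notification.Name {
    static let tokenGenerationCompleted = Notification.Name("TokenGenerationCompleted")
    static let tokenGenerationStats = Notification.Name("TokenGenerationStats")
}
```

With this in place, the handlers would read `.onReceive(NotificationCenter.default.publisher(for: .tokenGenerationCompleted))`; the Objective-C++ side still posts the same raw strings, so nothing changes on the wire.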
128 changes: 96 additions & 32 deletions mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm
@@ -5,45 +5,109 @@
#include "LocalLLM-Swift.h"
#include "ort_genai.h"
#include "ort_genai_c.h"

#include <chrono>

@implementation GenAIGenerator

+ (void)generate:(nonnull NSString*)input_user_question {
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
const char* modelPath = llmPath.cString;

auto model = OgaModel::Create(modelPath);
auto tokenizer = OgaTokenizer::Create(*model);

NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
const char* prompt = [promptString UTF8String];

auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);
typedef std::chrono::high_resolution_clock Clock;
typedef std::chrono::time_point<Clock> TimePoint;

auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 200);
params->SetInputSequences(*sequences);

// Streaming Output to generate token by token
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

auto generator = OgaGenerator::Create(*model, *params);
+ (void)generate:(nonnull NSString*)input_user_question {
NSLog(@"Starting token generation...");

NSString* llmPath = [[NSBundle mainBundle] resourcePath];
const char* modelPath = llmPath.cString;

// Log model creation
NSLog(@"Creating model ...");
auto model = OgaModel::Create(modelPath);
if (!model) {
NSLog(@"Failed to create model.");
return;
}

NSLog(@"Creating tokenizer...");
auto tokenizer = OgaTokenizer::Create(*model);
if (!tokenizer) {
NSLog(@"Failed to create tokenizer.");
return;
}

auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

// Construct the prompt
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
const char* prompt = [promptString UTF8String];

NSLog(@"Encoding prompt...");
auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);

// Log parameters
NSLog(@"Setting generator parameters...");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 200);
params->SetInputSequences(*sequences);

NSLog(@"Creating generator...");
auto generator = OgaGenerator::Create(*model, *params);

bool isFirstToken = true;
TimePoint startTime = Clock::now();
TimePoint firstTokenTime;
int tokenCount = 0;

NSLog(@"Starting token generation loop...");
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

if (isFirstToken) {
NSLog(@"First token generated.");
firstTokenTime = Clock::now();
isFirstToken = false;
}

// Get the sequence data
const int32_t* seq = generator->GetSequenceData(0);
size_t seq_len = generator->GetSequenceCount(0);

// Decode the new token
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);

// Check for decoding failure
if (!decode_tokens) {
NSLog(@"Token decoding failed.");
break;
}

NSLog(@"Decoded token: %s", decode_tokens);
tokenCount++;

// Convert token to NSString and update UI on the main thread
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
}

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

const int32_t* seq = generator->GetSequenceData(0);
size_t seq_len = generator->GetSequenceCount(0);
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
TimePoint endTime = Clock::now();
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();

NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %d", totalDuration, firstTokenDuration, tokenCount);

NSLog(@"Decoded tokens: %s", decode_tokens);
NSDictionary *stats = @{
@"totalTime": @(totalDuration),
@"firstTokenTime": @(firstTokenDuration),
@"tokenCount": @(tokenCount)
};

// Add decoded token to SharedTokenUpdater
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
}
// notify main thread that token generation is complete
dispatch_async(dispatch_get_main_queue(), ^{
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
});
NSLog(@"Token generation completed.");
}

@end
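
The timing above uses std::chrono: one timestamp at loop entry, one after the first decoded token, one at the end, with durations cast to milliseconds. One caveat visible in the code: if the generator produced no tokens, `firstTokenTime` would remain default-constructed and the reported first-token latency would be meaningless. A rough Swift equivalent of the same measurement scheme, shown only as a sketch — the `nextToken` closure is a stand-in for the generator loop, and the no-token case falls back to the end timestamp:

```swift
import Dispatch

// Sketch (not in this commit) of the same timing scheme in Swift:
// capture the clock at loop start, once after the first token, and
// once at the end, then report milliseconds.
struct GenerationStats {
    var totalTimeMs: Int
    var firstTokenTimeMs: Int
    var tokenCount: Int
}

func measureGeneration(nextToken: () -> String?) -> GenerationStats {
    let start = DispatchTime.now()
    var firstTokenTime: DispatchTime?
    var count = 0

    // Stand-in for the generator loop: pull tokens until the stream ends.
    while nextToken() != nil {
        if firstTokenTime == nil { firstTokenTime = DispatchTime.now() }
        count += 1
    }

    let end = DispatchTime.now()
    func ms(_ from: DispatchTime, _ to: DispatchTime) -> Int {
        Int((to.uptimeNanoseconds - from.uptimeNanoseconds) / 1_000_000)
    }
    return GenerationStats(
        totalTimeMs: ms(start, end),
        firstTokenTimeMs: ms(start, firstTokenTime ?? end),  // guard the no-token case
        tokenCount: count
    )
}
```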
4 changes: 3 additions & 1 deletion mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md
@@ -106,4 +106,6 @@ Upon app launching, Xcode will automatically copy and install the model files fr

**Note**: The current app is set up with only a simple initial prompt question; you can adjust it, try your own, or refine the UI based on your requirements.

***Notice:*** The current Xcode project runs on iOS 16.6; feel free to adjust the build settings for later iOS versions accordingly.

![App screenshot in the iPhone 16 simulator](<Simulator Screenshot - iPhone 16.png>)
mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift
@@ -14,4 +14,10 @@ import Foundation
self.decodedTokens.append(token)
}
}

@objc func clearTokens() {
DispatchQueue.main.async {
self.decodedTokens.removeAll()
}
}
}
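
Only the tail of this file appears in the diff. Pieced together from its call sites — `SharedTokenUpdater.shared` used as a singleton, the `$decodedTokens` publisher observed in ContentView, and `addDecodedToken:` invoked from Objective-C++ — the full class plausibly looks like the sketch below; the actual file may differ in details.

```swift
import Foundation
import Combine

// Sketch inferred from call sites elsewhere in this commit, not the actual
// file contents. The class must be visible to Objective-C (NSObject subclass
// with @objc members) and publish changes for SwiftUI (ObservableObject).
@objc class SharedTokenUpdater: NSObject, ObservableObject {
    @objc static let shared = SharedTokenUpdater()

    @Published var decodedTokens: [String] = []

    @objc func addDecodedToken(_ token: String) {
        DispatchQueue.main.async {
            self.decodedTokens.append(token)
        }
    }

    @objc func clearTokens() {
        DispatchQueue.main.async {
            self.decodedTokens.removeAll()
        }
    }
}
```

Routing both `append` and `removeAll` through `DispatchQueue.main.async` keeps the `@Published` mutations on the main thread, which SwiftUI requires for state that drives the view.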
mobile/examples/phi-3/ios/LocalLLM/LocalLLM/Simulator Screenshot - iPhone 16.png
Binary file not shown.
