Error converting Microsoft Phi3 model to ONNX using Python and Transformers #21518
Comments
There's a converted ONNX model available here: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx
Yeah, thanks, but here is the error I get with that converted ONNX Phi-3 model when initializing it through the onnxruntime Flutter package (https://pub.dev/packages/onnxruntime):

```dart
import 'dart:io';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';

class Phi3ChatModel {
  final OrtSessionOptions _sessionOptions;
  late OrtSession _session;

  Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
    OrtEnv.instance.init();
  }

  Future<void> initModel() async {
    print("-----------------------------------------initiation start -");
    _sessionOptions.setInterOpNumThreads(1);
    print("-----------------------------------------initiation step 1 -");
    _sessionOptions.setIntraOpNumThreads(1);
    print("-----------------------------------------initiation step 2 -");
    _sessionOptions.setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    print("-----------------------------------------initiation step 3 -");
    const assetFileName = 'assets/models/phi3_cpu.onnx';
    final rawAssetFile = await rootBundle.load(assetFileName);
    final bytes = rawAssetFile.buffer.asUint8List();
    print("-----------------------------------------initiation step 4 -");
    _session = OrtSession.fromBuffer(bytes, _sessionOptions);
    print("-----------------------------------------initiation step end -");
  }

  Future<String> predict(String inputData) async {
    // NOTE: this feeds a raw string tensor named 'input'; a Phi-3 decoder
    // graph expects int64 token ids (typically 'input_ids' plus an
    // 'attention_mask'), so this call would need a tokenizer in front of it.
    final inputTensor = OrtValueTensor.createTensorWithDataList([inputData], [1]);
    final inputs = {'input': inputTensor};
    final outputs = await _session.runAsync(OrtRunOptions(), inputs);
    inputTensor.release();
    final response = outputs?[0]?.value as List<String>;
    outputs?.forEach((element) => element?.release());
    return response.first;
  }

  void release() {
    _sessionOptions.release();
    _session.release();
    OrtEnv.instance.release();
  }
}
```

```dart
import 'package:flutter/material.dart';
import 'package:onnxruntime_example/features/phi3_chat_model.dart';

class ONNXChatScreen extends StatefulWidget {
  const ONNXChatScreen({super.key});

  @override
  _ONNXChatScreenState createState() => _ONNXChatScreenState();
}

class _ONNXChatScreenState extends State<ONNXChatScreen> {
  final TextEditingController _controller = TextEditingController();
  late Phi3ChatModel _chatbotModel;
  List<String> _messages = [];
  bool _isModelInitialized = false;

  @override
  void initState() {
    super.initState();
    _initializeModel();
  }

  Future<void> _initializeModel() async {
    _chatbotModel = Phi3ChatModel();
    try {
      await _chatbotModel.initModel();
      setState(() {
        _isModelInitialized = true;
      });
    } catch (e) {
      print('Error while initializing the model: $e');
    }
  }

  @override
  void dispose() {
    _chatbotModel.release();
    super.dispose();
  }

  void _sendMessage() async {
    final message = _controller.text;
    if (message.isEmpty) return;
    try {
      final response = await _chatbotModel.predict(message);
      setState(() {
        _messages.add('You: $message');
        _messages.add('Bot: $response');
        _controller.clear();
      });
    } catch (e) {
      print('Error while sending the message: $e');
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('ONNX Chatbot')),
      body: !_isModelInitialized
          ? const Center(child: CircularProgressIndicator())
          : Column(
              children: <Widget>[
                Expanded(
                  child: ListView.builder(
                    itemCount: _messages.length,
                    itemBuilder: (context, index) => ListTile(
                      title: Text(_messages[index]),
                    ),
                  ),
                ),
                Padding(
                  padding: const EdgeInsets.all(8.0),
                  child: Row(
                    children: <Widget>[
                      Expanded(
                        child: TextField(controller: _controller),
                      ),
                      IconButton(
                        icon: const Icon(Icons.send),
                        onPressed: _sendMessage,
                      ),
                    ],
                  ),
                ),
              ],
            ),
    );
  }
}
```

I downloaded the file from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx?download=true and renamed it in my Flutter assets. What could I have done wrong, please?
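One likely culprit worth checking first: that int4 Phi-3 export ships with a companion `.onnx.data` file holding the weights, so loading only the `.onnx` graph (especially from an in-memory buffer, as `OrtSession.fromBuffer` does) leaves ONNX Runtime unable to resolve the external data. A minimal sketch to verify the file outside Flutter, assuming the Python `onnxruntime` package is installed and both files sit in the working directory:

```python
# Sanity check: load the downloaded Phi-3 ONNX file with plain onnxruntime.
# If the companion .onnx.data file is missing from the same directory,
# session creation fails before any inference is attempted.
import onnxruntime as ort

session = ort.InferenceSession(
    "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
    providers=["CPUExecutionProvider"],
)

# Inspect the expected inputs: a decoder export takes int64 token ids
# (input_ids, attention_mask, past key/values), not raw strings.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
```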
Thanks for the quick reply :)
Here are some alternative options to export Phi-3 mini to ONNX.

**ONNX Runtime**

You can use ONNX Runtime's `convert_to_onnx` script:

```
python -m onnxruntime.transformers.models.llama.convert_to_onnx -m microsoft/Phi-3-mini-4k-instruct --output ./phi3_mini_4k --precision fp32 --execution_provider cpu
```

**Hugging Face's Optimum**

Instead of using `transformers.onnx`, you can use the `optimum-cli` exporter:

```
optimum-cli export onnx --model microsoft/Phi-3-mini-4k-instruct ./phi3_mini_4k
```

or you can use a simple Python script:

```python
from optimum.onnxruntime import ORTModelForCausalLM
model_name = "microsoft/Phi-3-mini-4k-instruct"
cache_dir = "./cache_dir"
model = ORTModelForCausalLM.from_pretrained(model_name, export=True, cache_dir=cache_dir)
model.save_pretrained("phi3_onnx/")
```
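For completeness, a minimal sketch of loading the exported folder back and smoke-testing generation (the prompt and generation arguments are illustrative; the tokenizer is pulled from the original model repo):

```python
# Reload the Optimum export and run a quick generation test.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model = ORTModelForCausalLM.from_pretrained("phi3_onnx/")
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```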
@junssashu Let me know if Kunal's solution works for you! If not, feel free to re-ping on this issue :)
Sorry for the time it took to respond. I've abandoned that method and switched to the GPT-2 mini model, which actually loads successfully in my app. My next issue is how to run inference on that loaded model to get a sensible output for a chatbot app. For now I'll mark the issue as closed, but I'm still trying the solution. Here is the code I use to load and run the model; I don't know what I'm doing wrong, so if someone can help I'd be grateful.

```dart
import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'dart:convert';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';

class Phi3ChatModel {
  late OrtSession _session;
  late OrtSessionOptions _sessionOptions;
  late Map<String, dynamic> _tokenizerConfig;
  late Map<String, dynamic> _generationConfig;
  late Map<String, dynamic> _modelConfig;
  late Map<String, dynamic> _tokenizer;
  late Map<String, dynamic> _vocab;
  late Map<String, dynamic> _specialTokensMap;
  late int _vocabSize;
  late String _bosToken;
  late String _eosToken;
  late String _unkToken;
  late int _maxLength;

  Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
    OrtEnv.instance.init();
  }
  Future<void> initModel() async {
    try {
      print("Initializing GPT-2 model...");
      _sessionOptions = OrtSessionOptions()
        ..setInterOpNumThreads(1)
        ..setIntraOpNumThreads(1)
        ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
      final appDocDir = await getApplicationDocumentsDirectory();
      // Load ONNX model
      final modelFile = File('${appDocDir.path}/decoder_model.onnx');
      if (!await modelFile.exists()) {
        final data = await rootBundle.load('assets/models/gpt2/decoder_model.onnx');
        await modelFile.writeAsBytes(data.buffer.asUint8List());
      }
      _session = OrtSession.fromFile(modelFile, _sessionOptions);
      // Load configuration files
      _tokenizerConfig = await _loadJsonFile('assets/models/gpt2/tokenizer_config.json');
      _generationConfig = await _loadJsonFile('assets/models/gpt2/generation_config.json');
      _modelConfig = await _loadJsonFile('assets/models/gpt2/config.json');
      _tokenizer = await _loadJsonFile('assets/models/gpt2/tokenizer.json');
      _specialTokensMap = await _loadJsonFile('assets/models/gpt2/special_tokens_map.json');
      _vocab = await _loadJsonFile('assets/models/gpt2/vocab.json');
      // Initialize model parameters
      _vocabSize = _modelConfig['vocab_size'];
      _bosToken = _specialTokensMap['bos_token'];
      _eosToken = _specialTokensMap['eos_token'];
      _unkToken = _specialTokensMap['unk_token'];
      _maxLength = _modelConfig['n_positions'];
      print("GPT-2 model initialized successfully.");
    } catch (e) {
      print("Error initializing GPT-2 model: $e");
      rethrow;
    }
  }

  Future<Map<String, dynamic>> _loadJsonFile(String path) async {
    try {
      print("Loading JSON file $path...");
      String jsonString = await rootBundle.loadString(path);
      print("JSON file $path loaded successfully.");
      return json.decode(jsonString);
    } catch (e) {
      print("Error loading JSON file $path: $e");
      rethrow;
    }
  }
  List<int> encode(String text) {
    print("Encoding text: \"$text\"");
    List<int> tokens = [];
    // NOTE: this splits on whitespace and looks words up directly in the
    // vocabulary. GPT-2 actually uses byte-level BPE, so most words will not
    // match a vocab entry exactly and are silently dropped here.
    for (String word in text.split(' ')) {
      if (_tokenizer['model']['vocab'].containsKey(word)) {
        tokens.add(_tokenizer['model']['vocab'][word]);
      } else {
        // TODO: handle unknown words
      }
    }
    print("Encoded text: $tokens");
    return tokens;
  }

  String decode(List<int> tokens) {
    print("Decoding tokens: $tokens");
    String text = '';
    for (int token in tokens) {
      if (_tokenizer['model']['vocab'].containsValue(token)) {
        String word = _tokenizer['model']['vocab'].keys
            .firstWhere((key) => _tokenizer['model']['vocab'][key] == token);
        text += '$word ';
      } else {
        // TODO: handle unknown tokens
      }
    }
    print("Decoded text: \"$text\"");
    return text.trim();
  }
  void processOutputTokens(Object outputTokensObject, List<int> generatedTokens) {
    // NOTE: this walks the nested logits output and samples one token from
    // every position's logits; normal generation samples only from the last
    // position and feeds the result back into the model in a loop.
    if (outputTokensObject is List) {
      if (outputTokensObject[0] is List) {
        for (var item in outputTokensObject) {
          processOutputTokens(item, generatedTokens);
        }
      }
      if (outputTokensObject[0] is double) {
        // cast<double>() tolerates a List<dynamic> coming back from the
        // runtime, where a direct `as List<double>` cast can throw.
        var probs = _softmax(outputTokensObject.cast<double>());
        var sample = _sampleFromProbs(probs);
        generatedTokens.add(sample);
      }
    }
  }
  Future<String> generateText(String prompt, {int maxNewTokens = 50}) async {
    try {
      print("Generating text with prompt: \"$prompt\"...");
      // Encode the prompt
      List<int> inputIds = encode(prompt);
      inputIds.insert(0, _modelConfig['bos_token_id']); // beginning-of-sequence token
      inputIds.add(_modelConfig['eos_token_id']);
      // Initialize an empty list to store the generated tokens
      List<int> generatedTokens = [];
      var inputTensor = OrtValueTensor.createTensorWithDataList(
        inputIds,
        [1, inputIds.length],
      );
      // Create attention mask (all ones since there is no padding)
      var attentionMask = OrtValueTensor.createTensorWithDataList(
        Int64List.fromList(List.filled(inputIds.length, 1)),
        [1, inputIds.length],
      );
      var ortInput = {
        'input_ids': inputTensor,
        'attention_mask': attentionMask,
      };
      // NOTE: a single forward pass only produces logits for the prompt
      // positions; generating new text requires running the model in a loop,
      // appending each sampled token to the input.
      final outputs = _session.run(
        OrtRunOptions(),
        ortInput,
      );
      var out1 = outputs[0]?.value as List;
      processOutputTokens(out1, generatedTokens);
      inputTensor.release();
      attentionMask.release();
      // Decode the generated tokens
      String result = decode(generatedTokens);
      print("Generated text: \"$result\"");
      return result;
    } catch (e) {
      print("Error generating text: $e");
      return "Error: Unable to generate text.";
    }
  }
  // Helper function to calculate softmax
  List<double> _softmax(List<double> logits) {
    print("Calculating softmax for logits: $logits");
    double maxLogit = logits.reduce((a, b) => a > b ? a : b);
    List<double> expLogits = logits.map((logit) => exp(logit - maxLogit)).toList();
    double sum = expLogits.reduce((a, b) => a + b);
    List<double> result = expLogits.map((expLogit) => expLogit / sum).toList();
    print("Softmax result: $result");
    return result;
  }

  // Helper function to sample from probabilities
  int _sampleFromProbs(List<double> probs) {
    print("Sampling from probabilities: $probs");
    Random rand = Random();
    // Draw the random threshold once; the cumulative comparison below then
    // samples each index i with probability probs[i].
    double threshold = rand.nextDouble();
    double cumulativeProb = 0.0;
    for (int i = 0; i < probs.length; i++) {
      cumulativeProb += probs[i];
      if (threshold < cumulativeProb) {
        print("Sampled token: $i");
        return i;
      }
    }
    int result = probs.length - 1; // Fallback to last token
    print("Sampled token: $result");
    return result;
  }
  void release() {
    print("Releasing resources...");
    _sessionOptions.release();
    _session.release();
    OrtEnv.instance.release();
    print("Resources released.");
  }
}
```

I'm using the onnxruntime Flutter package.
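The missing piece above is the autoregressive loop: the model has to be run repeatedly, sampling only from the logits of the last position and appending each new token to the input before the next pass. A minimal greedy-decoding sketch in Python against the same Optimum-style GPT-2 `decoder_model.onnx` (the input/output names match what the Dart code already feeds; the prompt and token budget are illustrative):

```python
# Minimal autoregressive (greedy) decoding loop over a GPT-2 ONNX decoder.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

session = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello, I am", return_tensors="np")["input_ids"].astype(np.int64)

for _ in range(30):  # max new tokens
    attention_mask = np.ones_like(input_ids)
    logits = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})[0]
    # Only the last position's logits predict the next token.
    next_token = int(np.argmax(logits[0, -1]))
    if next_token == tokenizer.eos_token_id:
        break
    input_ids = np.concatenate([input_ids, np.array([[next_token]], dtype=np.int64)], axis=1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```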
Describe the issue
Context
I encountered an error while attempting to convert a Microsoft Phi-3 model to ONNX format using Python and the Transformers library. The conversion process fails with a `KeyError` indicating that the Phi3 model is not supported.

Additional Context
I followed the standard procedure for model conversion using the Transformers library. However, it appears that the Phi3 model is not listed among the supported models for conversion. I am looking for guidance on how to proceed with this conversion or potential workarounds.
Error
Expected Behavior
Successful conversion of the Phi3 model to ONNX format.
Environment
To reproduce
Steps to Reproduce the Behavior
1. Install the required libraries (`onnxruntime`, `transformers`).
2. Attempt the conversion using the `transformers.onnx` module.

Code Snippet
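A sketch of a conversion attempt of roughly this shape (not the exact original snippet): `transformers.onnx` looks the architecture up in an internal feature table, and that lookup is where the `KeyError` surfaces for model types it does not know about, such as `phi3`:

```python
# Export attempt via the (since-deprecated) transformers.onnx machinery.
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager

model_name = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The feature-table lookup below raises KeyError for unsupported
# architectures such as phi3.
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(
    model, feature="causal-lm"
)
onnx_config = onnx_config_factory(model.config)
export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("phi3.onnx"))
```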