Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into baijumeswani/decoder-pipeline
  • Loading branch information
baijumeswani committed Dec 17, 2024
2 parents 43af9aa + 10932c1 commit 49f4345
Show file tree
Hide file tree
Showing 19 changed files with 183 additions and 67 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/ios-build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: "iOS ARM64 Build"
on:
workflow_dispatch:
push:
branches:
- main
- rel-*
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
iphonesimulator-arm64-build:
runs-on: macos-latest # arm64
steps:
- name: Checkout OnnxRuntime GenAI repo
uses: actions/checkout@v4
with:
submodules: true

- uses: actions/setup-python@v5
with:
python-version: '3.12.x'

- name: Install the python wheel and dependencies
run: |
python3 -m venv genai-macos-venv
source genai-macos-venv/bin/activate
python3 -m pip install requests
- name: Run iOS Build
run: |
set -e -x
source genai-macos-venv/bin/activate
python3 build.py --ios \
--parallel \
--apple_sysroot iphonesimulator \
--osx_arch arm64 \
--apple_deploy_target 15.4 \
--cmake_generator 'Xcode' \
--build_dir build_iphonesimulator
10 changes: 6 additions & 4 deletions src/java/src/main/java/ai/onnxruntime/genai/Adapters.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package ai.onnxruntime.genai;

/** A container of adapters. */
public final class Adapters implements AutoCloseable {
private long nativeHandle = 0;

Expand All @@ -22,13 +23,13 @@ public Adapters(Model model) throws GenAIException {
}

/**
* Load an adapter from the specified path.
* Loads the model adapter from the given adapter file path and adapter name.
*
* @param adapterFilePath The path of the adapter.
* @param adapterName A unique user supplied adapter identifier.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void loadAdapters(String adapterFilePath, String adapterName) throws GenAIException {
public void loadAdapter(String adapterFilePath, String adapterName) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
Expand All @@ -37,12 +38,13 @@ public void loadAdapters(String adapterFilePath, String adapterName) throws GenA
}

/**
* Unload an adapter.
* Unloads the adapter with the given identifier from the previously loaded adapters. If the
* adapter is not found, or if it cannot be unloaded (when it is in use), an error is returned.
*
* @param adapterName A unique user supplied adapter identifier.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void unloadAdapters(String adapterName) throws GenAIException {
public void unloadAdapter(String adapterName) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
Expand Down
34 changes: 29 additions & 5 deletions src/java/src/main/java/ai/onnxruntime/genai/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,56 @@
*/
package ai.onnxruntime.genai;

/**
* Use Config to set the ORT execution providers (EPs) and their options. The EPs are applied based on
* insertion order.
*/
public final class Config implements AutoCloseable {
private long nativeHandle;

/**
* Creates a Config from the given configuration directory.
*
* @param modelPath The path to the configuration directory.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Config(String modelPath) throws GenAIException {
nativeHandle = createConfig(modelPath);
}

/** Clear the list of providers in the config */
public void clearProviders() {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
clearProviders(nativeHandle);
}

public void appendProvider(String provider_name) {
/**
* Add the provider at the end of the list of providers in the given config if it doesn't already
* exist. If it already exists, does nothing.
*
* @param providerName The provider name.
*/
public void appendProvider(String providerName) {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
appendProvider(nativeHandle, provider_name);
appendProvider(nativeHandle, providerName);
}

public void setProviderOption(String provider_name, String option_name, String option_value) {
/**
* Set a provider option.
*
* @param providerName The provider name.
* @param optionKey The key of the option to set.
* @param optionValue The value of the option to set.
*/
public void setProviderOption(String providerName, String optionKey, String optionValue) {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
setProviderOption(nativeHandle, provider_name, option_name, option_value);
setProviderOption(nativeHandle, providerName, optionKey, optionValue);
}

@Override
Expand Down Expand Up @@ -60,5 +84,5 @@ long nativeHandle() {
private native void appendProvider(long configHandle, String provider_name);

private native void setProviderOption(
long configHandle, String provider_name, String option_name, String option_value);
long configHandle, String providerName, String optionKey, String optionValue);
}
4 changes: 2 additions & 2 deletions src/java/src/main/java/ai/onnxruntime/genai/GenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ final class GenAI {
/** The short name of the ONNX runtime shared library */
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";

/** The value of the {@link #GENAI_NATIVE_PATH} system property */
/** The value of the GENAI_NATIVE_PATH system property */
private static String libraryDirPathProperty;

/** The OS & CPU architecture string */
Expand Down Expand Up @@ -268,7 +268,7 @@ private static Optional<File> extractFromResources(String library) {

/**
* Maps the library name into a platform dependent library filename. Converts macOS's "jnilib" to
* "dylib" but otherwise is the same as {@link System#mapLibraryName(String)}.
* "dylib" but otherwise is the same as System#mapLibraryName(String).
*
* @param library The library name
* @return The library filename.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

/** An exception which contains the error message and code produced by the native layer. */
public final class GenAIException extends Exception {
public GenAIException(String message) {
GenAIException(String message) {
super(message);
}

public GenAIException(String message, Exception innerException) {
GenAIException(String message, Exception innerException) {
super(message, innerException);
}
}
11 changes: 7 additions & 4 deletions src/java/src/main/java/ai/onnxruntime/genai/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException {
}

/**
* Rewinds the generator by the specified number of tokens.
* Rewinds the generator to the given length. This is useful when the user wants to rewind the
* generator to a specific length and continue generating from that point.
*
* @param newLength The desired length in tokens after rewinding.
* @throws GenAIException If the call to the GenAI native API fails.
Expand All @@ -108,7 +109,8 @@ public void rewindTo(int newLength) throws GenAIException {
}

/**
* Generates the next token in the sequence.
* Computes the logits from the model based on the input ids and the past state. The computed
* logits are stored in the generator.
*
* @throws GenAIException If the call to the GenAI native API fails.
*/
Expand Down Expand Up @@ -151,9 +153,10 @@ public int getLastTokenInSequence(long sequenceIndex) throws GenAIException {
}

/**
* Fetches and returns the output tensor with the given name.
* Returns a copy of the model output identified by the given name as a Tensor.
*
* @param name The name of the output needed.
* @return The tensor.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Tensor getOutput(String name) throws GenAIException {
Expand All @@ -162,7 +165,7 @@ public Tensor getOutput(String name) throws GenAIException {
}

/**
* Activates one of the loaded adapters.
* Sets the adapter with the given adapter name as active.
*
* @param adapters The Adapters container.
* @param adapterName The adapter name that was previously loaded.
Expand Down
26 changes: 23 additions & 3 deletions src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,34 @@
import java.nio.ByteBuffer;

/**
* The `GeneratorParams` class represents the parameters used for generating sequences with a model.
* Set the prompt using setInput, and any other search options using setSearchOption.
* Represents the parameters used for generating sequences with a model. Set the prompt using
* setInput, and any other search options using setSearchOption.
*/
public final class GeneratorParams implements AutoCloseable {
private long nativeHandle = 0;
private ByteBuffer tokenIdsBuffer;

GeneratorParams(Model model) throws GenAIException {
/**
* Creates a GeneratorParams from the given model.
*
* @param model The model to use.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public GeneratorParams(Model model) throws GenAIException {
if (model.nativeHandle() == 0) {
throw new IllegalStateException("model has been freed and is invalid");
}

nativeHandle = createGeneratorParams(model.nativeHandle());
}

/**
* Set search option with double value.
*
* @param optionName The option name.
* @param value The option value.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void setSearchOption(String optionName, double value) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
Expand All @@ -30,6 +43,13 @@ public void setSearchOption(String optionName, double value) throws GenAIExcepti
setSearchOptionNumber(nativeHandle, optionName, value);
}

/**
* Set search option with boolean value.
*
* @param optionName The option name.
* @param value The option value.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void setSearchOption(String optionName, boolean value) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
Expand Down
7 changes: 7 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/Images.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
*/
package ai.onnxruntime.genai;

/** This class can load images from the given path and prepare them for processing. */
public class Images implements AutoCloseable {
private long nativeHandle;

/**
* Construct an Images instance.
*
* @param imagePath The image path.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Images(String imagePath) throws GenAIException {
nativeHandle = loadImages(imagePath);
}
Expand Down
38 changes: 9 additions & 29 deletions src/java/src/main/java/ai/onnxruntime/genai/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,28 @@
*/
package ai.onnxruntime.genai;

/** An ORT GenAI model. */
public final class Model implements AutoCloseable {
private long nativeHandle;

public Model(String modelPath) throws GenAIException {
nativeHandle = createModel(modelPath);
}

public Model(Config config) throws GenAIException {
nativeHandle = createModelFromConfig(config.nativeHandle());
}

/**
* Creates a Tokenizer instance for this model. The model contains the configuration information
* that determines the tokenizer to use.
* Construct a Model from folder path.
*
* @return The Tokenizer instance.
* @param modelPath The path of the GenAI model.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Tokenizer createTokenizer() throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

return new Tokenizer(this);
public Model(String modelPath) throws GenAIException {
nativeHandle = createModel(modelPath);
}

// NOTE: Having model.createGeneratorParams is still under discussion.
// model.createTokenizer is consistent with the python setup at least and agreed upon.

/**
* Creates a GeneratorParams instance for executing the model. NOTE: GeneratorParams internally
* uses the Model, so the Model instance must remain valid
* Construct a Model from the given Config.
*
* @return The GeneratorParams instance.
* @param config The config to use.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public GeneratorParams createGeneratorParams() throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

return new GeneratorParams(this);
public Model(Config config) throws GenAIException {
nativeHandle = createModelFromConfig(config.nativeHandle());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
public class MultiModalProcessor implements AutoCloseable {
private long nativeHandle;

/**
* Construct a MultiModalProcessor for a given model.
*
* @param model The model to be used.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public MultiModalProcessor(Model model) throws GenAIException {
assert (model.nativeHandle() != 0); // internal code should never pass an invalid model

Expand Down
8 changes: 8 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
*/
package ai.onnxruntime.genai;

/**
* This class is a list of tensors with names that match up with model input names.
*/
public class NamedTensors implements AutoCloseable {
private long nativeHandle;

/**
* Construct a NamedTensors instance from a native handle.
*
* @param handle The native handle.
*/
public NamedTensors(long handle) {
nativeHandle = handle;
}
Expand Down
Loading

0 comments on commit 49f4345

Please sign in to comment.