Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batch, vector stores, and fine tuning LRO and subclient pattern updates #11

Merged
merged 22 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 21 additions & 57 deletions .dotnet.azure/src/Custom/Batch/AzureBatchClient.Protocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,41 @@

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Text.Json;
using OpenAI.Batch;

namespace Azure.AI.OpenAI.Batch;

internal partial class AzureBatchClient : BatchClient
{
public override async Task<ClientResult> CreateBatchAsync(BinaryContent content, RequestOptions options = null)
public override async Task<CreateBatchOperation> CreateBatchAsync(BinaryContent content, bool waitUntilCompleted, RequestOptions options = null)
chschrae marked this conversation as resolved.
Show resolved Hide resolved
{
Argument.AssertNotNull(content, nameof(content));

using PipelineMessage message = CreateCreateBatchRequest(content, options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
PipelineResponse response = await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false);

using JsonDocument doc = JsonDocument.Parse(response.Content);
string batchId = doc.RootElement.GetProperty("id"u8).GetString();
string status = doc.RootElement.GetProperty("status"u8).GetString();

CreateBatchOperation operation = new(Pipeline, _endpoint, batchId, status, response);
return await operation.WaitUntilAsync(waitUntilCompleted, options).ConfigureAwait(false);
}

public override ClientResult CreateBatch(BinaryContent content, RequestOptions options = null)
public override CreateBatchOperation CreateBatch(BinaryContent content, bool waitUntilCompleted, RequestOptions options = null)
{
Argument.AssertNotNull(content, nameof(content));

using PipelineMessage message = CreateCreateBatchRequest(content, options);
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
PipelineResponse response = Pipeline.ProcessMessage(message, options);

using JsonDocument doc = JsonDocument.Parse(response.Content);
string batchId = doc.RootElement.GetProperty("id"u8).GetString();
string status = doc.RootElement.GetProperty("status"u8).GetString();

CreateBatchOperation operation = new(Pipeline, _endpoint, batchId, status, response);
return operation.WaitUntil(waitUntilCompleted, options);
}

public override async Task<ClientResult> GetBatchesAsync(string after, int? limit, RequestOptions options)
Expand All @@ -37,70 +52,19 @@ public override ClientResult GetBatches(string after, int? limit, RequestOptions
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

public override async Task<ClientResult> GetBatchAsync(string batchId, RequestOptions options)
internal override async Task<ClientResult> GetBatchAsync(string batchId, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(batchId, nameof(batchId));

using PipelineMessage message = CreateRetrieveBatchRequest(batchId, options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

public override ClientResult GetBatch(string batchId, RequestOptions options)
internal override ClientResult GetBatch(string batchId, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(batchId, nameof(batchId));

using PipelineMessage message = CreateRetrieveBatchRequest(batchId, options);
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

public override async Task<ClientResult> CancelBatchAsync(string batchId, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(batchId, nameof(batchId));

using PipelineMessage message = CreateCancelBatchRequest(batchId, options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

public override ClientResult CancelBatch(string batchId, RequestOptions options)
{
Argument.AssertNotNullOrEmpty(batchId, nameof(batchId));

using PipelineMessage message = CreateCancelBatchRequest(batchId, options);
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

private new PipelineMessage CreateCreateBatchRequest(BinaryContent content, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("POST")
.WithPath("batches")
.WithContent(content, "application/json")
.WithAccept("application/json")
.WithOptions(options)
.Build();

private new PipelineMessage CreateGetBatchesRequest(string after, int? limit, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("GET")
.WithPath("batches")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
.WithOptions(options)
.Build();

private new PipelineMessage CreateRetrieveBatchRequest(string batchId, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("GET")
.WithPath("batches", batchId)
.WithAccept("application/json")
.WithOptions(options)
.Build();

private new PipelineMessage CreateCancelBatchRequest(string batchId, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("POST")
.WithPath("batches", batchId, "cancel")
.WithAccept("application/json")
.WithOptions(options)
.Build();
}
42 changes: 41 additions & 1 deletion .dotnet.azure/src/Custom/Batch/AzureBatchClient.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel;
using System.ClientModel.Primitives;
using OpenAI.Batch;

namespace Azure.AI.OpenAI.Batch;

Expand Down Expand Up @@ -34,4 +34,44 @@ internal AzureBatchClient(

protected AzureBatchClient()
{ }

internal override CreateBatchOperation CreateCreateBatchOperation(string batchId, string status, PipelineResponse response)
{
return new AzureCreateBatchOperation(Pipeline, _endpoint, batchId, status, response, _deploymentName, _apiVersion);
}

internal override PipelineMessage CreateCreateBatchRequest(BinaryContent content, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("POST")
.WithPath("batches")
.WithContent(content, "application/json")
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateGetBatchesRequest(string after, int? limit, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("GET")
.WithPath("batches")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateRetrieveBatchRequest(string batchId, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("GET")
.WithPath("batches", batchId)
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateCancelBatchRequest(string batchId, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("POST")
.WithPath("batches", batchId, "cancel")
.WithAccept("application/json")
.WithOptions(options)
.Build();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using System.ClientModel;
using System.ClientModel.Primitives;

#nullable enable

namespace Azure.AI.OpenAI.Batch;

/// <summary>
/// A long-running operation for executing a batch from an uploaded file of
/// requests.
/// </summary>
public partial class AzureCreateBatchOperation : CreateBatchOperation
{
private readonly ClientPipeline _pipeline;
private readonly Uri _endpoint;
private readonly string _batchId;

private readonly string _deploymentName;
private readonly string _apiVersion;

internal AzureCreateBatchOperation(
ClientPipeline pipeline,
Uri endpoint,
string batchId,
string status,
PipelineResponse response,
string deploymentName,
string apiVersion)
: base(pipeline, endpoint, batchId, status, response)
{
_pipeline = pipeline;
_endpoint = endpoint;
_batchId = batchId;
_deploymentName = deploymentName;
_apiVersion = apiVersion;
}

internal override PipelineMessage CreateRetrieveBatchRequest(string batchId, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("GET")
.WithPath("batches", batchId)
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateCancelBatchRequest(string batchId, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion, _deploymentName)
.WithMethod("POST")
.WithPath("batches", batchId, "cancel")
.WithAccept("application/json")
.WithOptions(options)
.Build();

private static PipelineMessageClassifier? _pipelineMessageClassifier200;
private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 });
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Diagnostics.CodeAnalysis;

#nullable enable

namespace Azure.AI.OpenAI.FineTuning;

/// <summary>
/// A long-running operation for creating a new model from a given dataset.
/// </summary>
public class AzureCreateJobOperation : CreateJobOperation
chschrae marked this conversation as resolved.
Show resolved Hide resolved
{
private readonly PipelineMessageClassifier DeleteJobClassifier = PipelineMessageClassifier.Create(stackalloc ushort[] { 204 });
private readonly ClientPipeline _pipeline;
private readonly Uri _endpoint;
private readonly string _jobId;

private readonly string _apiVersion;

internal AzureCreateJobOperation(
ClientPipeline pipeline,
Uri endpoint,
string jobId,
string status,
PipelineResponse response,
string apiVersion)
: base(pipeline, endpoint, jobId, status, response)
{
_pipeline = pipeline;
_endpoint = endpoint;
_jobId = jobId;
_apiVersion = apiVersion;
}

[Experimental("AOAI001")]
public virtual ClientResult DeleteJob(string fineTuningJobId, RequestOptions? options)
chschrae marked this conversation as resolved.
Show resolved Hide resolved
{
using PipelineMessage message = CreateDeleteJobRequestMessage(fineTuningJobId, options);
return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options));
}

[Experimental("AOAI001")]
public virtual async Task<ClientResult> DeleteJobAsync(string fineTuningJobId, RequestOptions? options)
{
using PipelineMessage message = CreateDeleteJobRequestMessage(fineTuningJobId, options);
PipelineResponse response = await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false);
return ClientResult.FromResponse(response);
}

private PipelineMessage CreateDeleteJobRequestMessage(string fineTuningJobId, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion)
.WithMethod("DELETE")
.WithPath("fine_tuning", "jobs", fineTuningJobId)
.WithAccept("application/json")
.WithClassifier(DeleteJobClassifier)
.WithOptions(options)
.Build();

internal override PipelineMessage CreateRetrieveFineTuningJobRequest(string fineTuningJobId, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion)
.WithMethod("GET")
.WithPath("fine_tuning", "jobs", fineTuningJobId)
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateCancelFineTuningJobRequest(string fineTuningJobId, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion)
.WithMethod("POST")
.WithPath("fine_tuning", "jobs", fineTuningJobId, "cancel")
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateGetFineTuningJobCheckpointsRequest(string fineTuningJobId, string after, int? limit, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion)
.WithMethod("GET")
.WithPath("fine_tuning", "jobs", fineTuningJobId, "checkpoints")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
.WithOptions(options)
.Build();

internal override PipelineMessage CreateGetFineTuningEventsRequest(string fineTuningJobId, string after, int? limit, RequestOptions? options)
=> new AzureOpenAIPipelineMessageBuilder(_pipeline, _endpoint, _apiVersion)
.WithMethod("GET")
.WithPath("fine_tuning", "jobs", fineTuningJobId, "events")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
.WithOptions(options)
.Build();
}
Loading