From d1ea5fa30aecf22d94e37e2231f7fc9b28e2abc5 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Fri, 6 Dec 2024 17:50:56 -0500 Subject: [PATCH] [Vertex AI] Add `ImageGenerationRequest` for Imagen (#14225) --- .../Sources/GenerativeAIRequest.swift | 3 + .../Imagen/ImageGenerationInstance.swift | 3 + .../Imagen/ImageGenerationOutputOptions.swift | 3 + .../Imagen/ImageGenerationParameters.swift | 3 + .../Imagen/ImageGenerationRequest.swift | 54 +++++++ .../Imagen/ImageGenerationRequestTests.swift | 140 ++++++++++++++++++ 6 files changed, 206 insertions(+) create mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift diff --git a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift index b792830120e..9b3b7330703 100644 --- a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift +++ b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift @@ -41,3 +41,6 @@ public struct RequestOptions { self.timeout = timeout } } + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension RequestOptions: Equatable {} diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift index 6025760fce8..c1d853643cc 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift @@ -17,6 +17,9 @@ struct ImageGenerationInstance { let prompt: String } +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationInstance: Equatable {} + // MARK: - Codable Conformance @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationOutputOptions.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationOutputOptions.swift index 6ac24776116..0b678187ee2 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationOutputOptions.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationOutputOptions.swift @@ -20,6 +20,9 @@ struct ImageGenerationOutputOptions { let compressionQuality: Int? } +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationOutputOptions: Equatable {} + // MARK: - Codable Conformance @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift index 6e5380666cd..9ab6641862c 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift @@ -26,6 +26,9 @@ struct ImageGenerationParameters { let includeResponsibleAIFilterReason: Bool? } +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationParameters: Equatable {} + // MARK: - Codable Conformance @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift new file mode 100644 index 00000000000..da972b4403b --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct ImageGenerationRequest { + let model: String + let options: RequestOptions + let instances: [ImageGenerationInstance] + let parameters: ImageGenerationParameters + + init(model: String, options: RequestOptions, instances: [ImageGenerationInstance], + parameters: ImageGenerationParameters) { + self.model = model + self.options = options + self.instances = instances + self.parameters = parameters + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable { + typealias Response = ImageGenerationResponse + + var url: URL { + return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")! + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImageGenerationRequest: Encodable { + enum CodingKeys: CodingKey { + case instances + case parameters + } + + func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(instances, forKey: .instances) + try container.encode(parameters, forKey: .parameters) + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift new file mode 100644 index 00000000000..90ca9e7c25d --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift @@ -0,0 +1,140 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class ImageGenerationRequestTests: XCTestCase { + let encoder = JSONEncoder() + let requestOptions = RequestOptions(timeout: 30.0) + let modelName = "test-model-name" + let sampleCount = 4 + let aspectRatio = "16:9" + let safetyFilterLevel = "block_low_and_above" + let includeResponsibleAIFilterReason = true + lazy var parameters = ImageGenerationParameters( + sampleCount: sampleCount, + storageURI: nil, + seed: nil, + negativePrompt: nil, + aspectRatio: aspectRatio, + safetyFilterLevel: safetyFilterLevel, + personGeneration: nil, + outputOptions: nil, + addWatermark: nil, + includeResponsibleAIFilterReason: includeResponsibleAIFilterReason + ) + + let instance = ImageGenerationInstance(prompt: "test-prompt") + + override func setUp() { + encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes] + } + + func testInitializeRequest_inlineDataImage() throws { + let request = ImageGenerationRequest( + model: modelName, + options: requestOptions, + instances: [instance], + parameters: parameters + ) + + XCTAssertEqual(request.model, modelName) + XCTAssertEqual(request.options, requestOptions) + XCTAssertEqual(request.instances, [instance]) + XCTAssertEqual(request.parameters, parameters) + XCTAssertEqual( + request.url, + URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict") + ) + } + + func testInitializeRequest_fileDataImage() throws { + let request = ImageGenerationRequest( + model: modelName, + options: requestOptions, + instances: [instance], + parameters: parameters + ) + + XCTAssertEqual(request.model, modelName) + XCTAssertEqual(request.options, requestOptions) + XCTAssertEqual(request.instances, [instance]) + XCTAssertEqual(request.parameters, parameters) + XCTAssertEqual( + request.url, + URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict") + ) + } + + // MARK: - Encoding Tests + + func testEncodeRequest_inlineDataImage() throws { + let request = ImageGenerationRequest( + model: modelName, + options: RequestOptions(), + instances: [instance], + parameters: parameters + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "instances" : [ + { + "prompt" : "\(instance.prompt)" + } + ], + "parameters" : { + "aspectRatio" : "\(aspectRatio)", + "includeRaiReason" : \(includeResponsibleAIFilterReason), + "safetySetting" : "\(safetyFilterLevel)", + "sampleCount" : \(sampleCount) + } + } + """) + } + + func testEncodeRequest_fileDataImage() throws { + let request = ImageGenerationRequest( + model: modelName, + options: RequestOptions(), + instances: [instance], + parameters: parameters + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "instances" : [ + { + "prompt" : "\(instance.prompt)" + } + ], + "parameters" : { + "aspectRatio" : "\(aspectRatio)", + "includeRaiReason" : \(includeResponsibleAIFilterReason), + "safetySetting" : "\(safetyFilterLevel)", + "sampleCount" : \(sampleCount) + } + } + """) + } +}