Skip to content

Commit

Permalink
[Vertex AI] Add ImageGenerationRequest for Imagen (#14225)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 9, 2024
1 parent c2888bd commit d1ea5fa
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 0 deletions.
3 changes: 3 additions & 0 deletions FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ public struct RequestOptions {
self.timeout = timeout
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension RequestOptions: Equatable {}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ struct ImageGenerationInstance {
let prompt: String
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationInstance: Equatable {}

// MARK: - Codable Conformance

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ struct ImageGenerationOutputOptions {
let compressionQuality: Int?
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationOutputOptions: Equatable {}

// MARK: - Codable Conformance

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ struct ImageGenerationParameters {
let includeResponsibleAIFilterReason: Bool?
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationParameters: Equatable {}

// MARK: - Codable Conformance

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
let model: String
let options: RequestOptions
let instances: [ImageGenerationInstance]
let parameters: ImageGenerationParameters

init(model: String, options: RequestOptions, instances: [ImageGenerationInstance],
parameters: ImageGenerationParameters) {
self.model = model
self.options = options
self.instances = instances
self.parameters = parameters
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImageGenerationResponse<ImageType>

var url: URL {
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")!
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: Encodable {
enum CodingKeys: CodingKey {
case instances
case parameters
}

func encode(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(instances, forKey: .instances)
try container.encode(parameters, forKey: .parameters)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationRequestTests: XCTestCase {
let encoder = JSONEncoder()
let requestOptions = RequestOptions(timeout: 30.0)
let modelName = "test-model-name"
let sampleCount = 4
let aspectRatio = "16:9"
let safetyFilterLevel = "block_low_and_above"
let includeResponsibleAIFilterReason = true
lazy var parameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: nil,
seed: nil,
negativePrompt: nil,
aspectRatio: aspectRatio,
safetyFilterLevel: safetyFilterLevel,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: includeResponsibleAIFilterReason
)

let instance = ImageGenerationInstance(prompt: "test-prompt")

override func setUp() {
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
}

func testInitializeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
parameters: parameters
)

XCTAssertEqual(request.model, modelName)
XCTAssertEqual(request.options, requestOptions)
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict")
)
}

func testInitializeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
parameters: parameters
)

XCTAssertEqual(request.model, modelName)
XCTAssertEqual(request.options, requestOptions)
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict")
)
}

// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
parameters: parameters
)

let jsonData = try encoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"instances" : [
{
"prompt" : "\(instance.prompt)"
}
],
"parameters" : {
"aspectRatio" : "\(aspectRatio)",
"includeRaiReason" : \(includeResponsibleAIFilterReason),
"safetySetting" : "\(safetyFilterLevel)",
"sampleCount" : \(sampleCount)
}
}
""")
}

func testEncodeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
parameters: parameters
)

let jsonData = try encoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"instances" : [
{
"prompt" : "\(instance.prompt)"
}
],
"parameters" : {
"aspectRatio" : "\(aspectRatio)",
"includeRaiReason" : \(includeResponsibleAIFilterReason),
"safetySetting" : "\(safetyFilterLevel)",
"sampleCount" : \(sampleCount)
}
}
""")
}
}

0 comments on commit d1ea5fa

Please sign in to comment.