-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Vertex AI] Add
ImageGenerationRequest
for Imagen (#14225)
- Loading branch information
1 parent
c2888bd
commit d1ea5fa
Showing
6 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> { | ||
let model: String | ||
let options: RequestOptions | ||
let instances: [ImageGenerationInstance] | ||
let parameters: ImageGenerationParameters | ||
|
||
init(model: String, options: RequestOptions, instances: [ImageGenerationInstance], | ||
parameters: ImageGenerationParameters) { | ||
self.model = model | ||
self.options = options | ||
self.instances = instances | ||
self.parameters = parameters | ||
} | ||
} | ||
|
||
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable { | ||
typealias Response = ImageGenerationResponse<ImageType> | ||
|
||
var url: URL { | ||
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")! | ||
} | ||
} | ||
|
||
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
extension ImageGenerationRequest: Encodable { | ||
enum CodingKeys: CodingKey { | ||
case instances | ||
case parameters | ||
} | ||
|
||
func encode(to encoder: any Encoder) throws { | ||
var container = encoder.container(keyedBy: CodingKeys.self) | ||
try container.encode(instances, forKey: .instances) | ||
try container.encode(parameters, forKey: .parameters) | ||
} | ||
} |
140 changes: 140 additions & 0 deletions
140
FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import XCTest | ||
|
||
@testable import FirebaseVertexAI | ||
|
||
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) | ||
final class ImageGenerationRequestTests: XCTestCase { | ||
let encoder = JSONEncoder() | ||
let requestOptions = RequestOptions(timeout: 30.0) | ||
let modelName = "test-model-name" | ||
let sampleCount = 4 | ||
let aspectRatio = "16:9" | ||
let safetyFilterLevel = "block_low_and_above" | ||
let includeResponsibleAIFilterReason = true | ||
lazy var parameters = ImageGenerationParameters( | ||
sampleCount: sampleCount, | ||
storageURI: nil, | ||
seed: nil, | ||
negativePrompt: nil, | ||
aspectRatio: aspectRatio, | ||
safetyFilterLevel: safetyFilterLevel, | ||
personGeneration: nil, | ||
outputOptions: nil, | ||
addWatermark: nil, | ||
includeResponsibleAIFilterReason: includeResponsibleAIFilterReason | ||
) | ||
|
||
let instance = ImageGenerationInstance(prompt: "test-prompt") | ||
|
||
override func setUp() { | ||
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes] | ||
} | ||
|
||
func testInitializeRequest_inlineDataImage() throws { | ||
let request = ImageGenerationRequest<ImagenInlineDataImage>( | ||
model: modelName, | ||
options: requestOptions, | ||
instances: [instance], | ||
parameters: parameters | ||
) | ||
|
||
XCTAssertEqual(request.model, modelName) | ||
XCTAssertEqual(request.options, requestOptions) | ||
XCTAssertEqual(request.instances, [instance]) | ||
XCTAssertEqual(request.parameters, parameters) | ||
XCTAssertEqual( | ||
request.url, | ||
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict") | ||
) | ||
} | ||
|
||
func testInitializeRequest_fileDataImage() throws { | ||
let request = ImageGenerationRequest<ImagenFileDataImage>( | ||
model: modelName, | ||
options: requestOptions, | ||
instances: [instance], | ||
parameters: parameters | ||
) | ||
|
||
XCTAssertEqual(request.model, modelName) | ||
XCTAssertEqual(request.options, requestOptions) | ||
XCTAssertEqual(request.instances, [instance]) | ||
XCTAssertEqual(request.parameters, parameters) | ||
XCTAssertEqual( | ||
request.url, | ||
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict") | ||
) | ||
} | ||
|
||
// MARK: - Encoding Tests | ||
|
||
func testEncodeRequest_inlineDataImage() throws { | ||
let request = ImageGenerationRequest<ImagenInlineDataImage>( | ||
model: modelName, | ||
options: RequestOptions(), | ||
instances: [instance], | ||
parameters: parameters | ||
) | ||
|
||
let jsonData = try encoder.encode(request) | ||
|
||
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) | ||
XCTAssertEqual(json, """ | ||
{ | ||
"instances" : [ | ||
{ | ||
"prompt" : "\(instance.prompt)" | ||
} | ||
], | ||
"parameters" : { | ||
"aspectRatio" : "\(aspectRatio)", | ||
"includeRaiReason" : \(includeResponsibleAIFilterReason), | ||
"safetySetting" : "\(safetyFilterLevel)", | ||
"sampleCount" : \(sampleCount) | ||
} | ||
} | ||
""") | ||
} | ||
|
||
func testEncodeRequest_fileDataImage() throws { | ||
let request = ImageGenerationRequest<ImagenFileDataImage>( | ||
model: modelName, | ||
options: RequestOptions(), | ||
instances: [instance], | ||
parameters: parameters | ||
) | ||
|
||
let jsonData = try encoder.encode(request) | ||
|
||
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) | ||
XCTAssertEqual(json, """ | ||
{ | ||
"instances" : [ | ||
{ | ||
"prompt" : "\(instance.prompt)" | ||
} | ||
], | ||
"parameters" : { | ||
"aspectRatio" : "\(aspectRatio)", | ||
"includeRaiReason" : \(includeResponsibleAIFilterReason), | ||
"safetySetting" : "\(safetyFilterLevel)", | ||
"sampleCount" : \(sampleCount) | ||
} | ||
} | ||
""") | ||
} | ||
} |