From abdb9c4de8beb0188aaaae8a9b817ae28cd27003 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 6 Jan 2025 17:59:41 -0500 Subject: [PATCH 1/2] [Vertex AI] Refactor `ImagenSafetySettings` --- .../Types/Public/Imagen/ImagenModel.swift | 4 +- .../Imagen/ImagenPersonFilterLevel.swift | 28 +++++++++++ .../Imagen/ImagenSafetyFilterLevel.swift | 30 ++++++++++++ .../Public/Imagen/ImagenSafetySettings.swift | 48 ++----------------- .../ImageGenerationParametersTests.swift | 24 ++++------ 5 files changed, 75 insertions(+), 59 deletions(-) create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenPersonFilterLevel.swift create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetyFilterLevel.swift diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift index 9d5e7887969..afd7d370448 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift @@ -98,7 +98,7 @@ public final class ImagenModel { negativePrompt: generationConfig?.negativePrompt, aspectRatio: generationConfig?.aspectRatio?.rawValue, safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue, - personGeneration: safetySettings?.personGeneration?.rawValue, + personGeneration: safetySettings?.personFilterLevel?.rawValue, outputOptions: generationConfig?.imageFormat.map { ImageGenerationOutputOptions( mimeType: $0.mimeType, @@ -106,7 +106,7 @@ public final class ImagenModel { ) }, addWatermark: generationConfig?.addWatermark, - includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true + includeResponsibleAIFilterReason: true ) } } diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenPersonFilterLevel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenPersonFilterLevel.swift new file mode 100644 index 00000000000..b3eda03754d --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenPersonFilterLevel.swift @@ -0,0 +1,28 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenPersonFilterLevel: ProtoEnum { + enum Kind: String { + case blockAll = "dont_allow" + case allowAdult = "allow_adult" + case allowAll = "allow_all" + } + + public static let blockAll = ImagenPersonFilterLevel(kind: .blockAll) + public static let allowAdult = ImagenPersonFilterLevel(kind: .allowAdult) + public static let allowAll = ImagenPersonFilterLevel(kind: .allowAll) + + let rawValue: String +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetyFilterLevel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetyFilterLevel.swift new file mode 100644 index 00000000000..32b8e4f1a02 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetyFilterLevel.swift @@ -0,0 +1,30 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenSafetyFilterLevel: ProtoEnum { + enum Kind: String { + case blockLowAndAbove = "block_low_and_above" + case blockMediumAndAbove = "block_medium_and_above" + case blockOnlyHigh = "block_only_high" + case blockNone = "block_none" + } + + public static let blockLowAndAbove = ImagenSafetyFilterLevel(kind: .blockLowAndAbove) + public static let blockMediumAndAbove = ImagenSafetyFilterLevel(kind: .blockMediumAndAbove) + public static let blockOnlyHigh = ImagenSafetyFilterLevel(kind: .blockOnlyHigh) + public static let blockNone = ImagenSafetyFilterLevel(kind: .blockNone) + + let rawValue: String +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift index 81b72890a0e..91a28891223 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift @@ -16,50 +16,12 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct ImagenSafetySettings { - let safetyFilterLevel: SafetyFilterLevel? - let includeFilterReason: Bool? - let personGeneration: PersonGeneration? + let safetyFilterLevel: ImagenSafetyFilterLevel? + let personFilterLevel: ImagenPersonFilterLevel? - public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil, - personGeneration: PersonGeneration? = nil) { + public init(safetyFilterLevel: ImagenSafetyFilterLevel? = nil, + personFilterLevel: ImagenPersonFilterLevel? = nil) { self.safetyFilterLevel = safetyFilterLevel - self.includeFilterReason = includeFilterReason - self.personGeneration = personGeneration - } -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public extension ImagenSafetySettings { - struct SafetyFilterLevel: ProtoEnum { - enum Kind: String { - case blockLowAndAbove = "block_low_and_above" - case blockMediumAndAbove = "block_medium_and_above" - case blockOnlyHigh = "block_only_high" - case blockNone = "block_none" - } - - public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove) - public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove) - public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh) - public static let blockNone = SafetyFilterLevel(kind: .blockNone) - - let rawValue: String - } -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public extension ImagenSafetySettings { - struct PersonGeneration: ProtoEnum { - enum Kind: String { - case blockAll = "dont_allow" - case allowAdult = "allow_adult" - case allowAll = "allow_all" - } - - public static let blockAll = PersonGeneration(kind: .blockAll) - public static let allowAdult = PersonGeneration(kind: .allowAdult) - public static let allowAll = PersonGeneration(kind: .allowAll) - - let rawValue: String + self.personFilterLevel = personFilterLevel } } diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift index f50f1646c8d..6b313afcd8c 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift @@ -111,13 +111,11 @@ final class ImageGenerationParametersTests: XCTestCase { } func testDefaultParameters_includeSafetySettings() throws { - let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockOnlyHigh - let personGeneration = ImagenSafetySettings.PersonGeneration.allowAll - let includeFilterReason = false + let safetyFilterLevel = ImagenSafetyFilterLevel.blockOnlyHigh + let personFilterLevel = ImagenPersonFilterLevel.allowAll let safetySettings = ImagenSafetySettings( safetyFilterLevel: safetyFilterLevel, - includeFilterReason: includeFilterReason, - personGeneration: personGeneration + personFilterLevel: personFilterLevel ) let expectedParameters = ImageGenerationParameters( sampleCount: 1, @@ -125,10 +123,10 @@ final class ImageGenerationParametersTests: XCTestCase { negativePrompt: nil, aspectRatio: nil, safetyFilterLevel: safetyFilterLevel.rawValue, - personGeneration: personGeneration.rawValue, + personGeneration: personFilterLevel.rawValue, outputOptions: nil, addWatermark: nil, - includeResponsibleAIFilterReason: includeFilterReason + includeResponsibleAIFilterReason: true ) let parameters = ImagenModel.imageGenerationParameters( @@ -156,13 +154,11 @@ final class ImageGenerationParametersTests: XCTestCase { imageFormat: imageFormat, addWatermark: addWatermark ) - let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockNone - let personGeneration = ImagenSafetySettings.PersonGeneration.blockAll - let includeFilterReason = false + let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone + let personFilterLevel = ImagenPersonFilterLevel.blockAll let safetySettings = ImagenSafetySettings( safetyFilterLevel: safetyFilterLevel, - includeFilterReason: includeFilterReason, - personGeneration: personGeneration + personFilterLevel: personFilterLevel ) let expectedParameters = ImageGenerationParameters( sampleCount: sampleCount, @@ -170,13 +166,13 @@ final class ImageGenerationParametersTests: XCTestCase { negativePrompt: negativePrompt, aspectRatio: aspectRatio.rawValue, safetyFilterLevel: safetyFilterLevel.rawValue, - personGeneration: personGeneration.rawValue, + personGeneration: personFilterLevel.rawValue, outputOptions: ImageGenerationOutputOptions( mimeType: imageFormat.mimeType, compressionQuality: imageFormat.compressionQuality ), addWatermark: addWatermark, - includeResponsibleAIFilterReason: includeFilterReason + includeResponsibleAIFilterReason: true ) let parameters = ImagenModel.imageGenerationParameters( From 962468bae090e69e25383583f12236c00390534d Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 6 Jan 2025 18:20:28 -0500 Subject: [PATCH 2/2] Fix integration test --- .../Tests/TestApp/Tests/Integration/IntegrationTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift index 4e233ffee38..69ede59d802 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift @@ -65,7 +65,7 @@ final class IntegrationTests: XCTestCase { modelName: "imagen-3.0-fast-generate-001", safetySettings: ImagenSafetySettings( safetyFilterLevel: .blockLowAndAbove, - personGeneration: .blockAll + personFilterLevel: .blockAll ) )