Skip to content

Commit

Permalink
[Vertex AI] Refactor ImagenSafetySettings (#14307)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 7, 2025
1 parent 1de7f6f commit 53d43d8
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ public final class ImagenModel {
negativePrompt: generationConfig?.negativePrompt,
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personGeneration?.rawValue,
personGeneration: safetySettings?.personFilterLevel?.rawValue,
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true
includeResponsibleAIFilterReason: true
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenPersonFilterLevel: ProtoEnum {
enum Kind: String {
case blockAll = "dont_allow"
case allowAdult = "allow_adult"
case allowAll = "allow_all"
}

public static let blockAll = ImagenPersonFilterLevel(kind: .blockAll)
public static let allowAdult = ImagenPersonFilterLevel(kind: .allowAdult)
public static let allowAll = ImagenPersonFilterLevel(kind: .allowAll)

let rawValue: String
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenSafetyFilterLevel: ProtoEnum {
enum Kind: String {
case blockLowAndAbove = "block_low_and_above"
case blockMediumAndAbove = "block_medium_and_above"
case blockOnlyHigh = "block_only_high"
case blockNone = "block_none"
}

public static let blockLowAndAbove = ImagenSafetyFilterLevel(kind: .blockLowAndAbove)
public static let blockMediumAndAbove = ImagenSafetyFilterLevel(kind: .blockMediumAndAbove)
public static let blockOnlyHigh = ImagenSafetyFilterLevel(kind: .blockOnlyHigh)
public static let blockNone = ImagenSafetyFilterLevel(kind: .blockNone)

let rawValue: String
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,12 @@ import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenSafetySettings {
let safetyFilterLevel: SafetyFilterLevel?
let includeFilterReason: Bool?
let personGeneration: PersonGeneration?
let safetyFilterLevel: ImagenSafetyFilterLevel?
let personFilterLevel: ImagenPersonFilterLevel?

public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil,
personGeneration: PersonGeneration? = nil) {
public init(safetyFilterLevel: ImagenSafetyFilterLevel? = nil,
personFilterLevel: ImagenPersonFilterLevel? = nil) {
self.safetyFilterLevel = safetyFilterLevel
self.includeFilterReason = includeFilterReason
self.personGeneration = personGeneration
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct SafetyFilterLevel: ProtoEnum {
enum Kind: String {
case blockLowAndAbove = "block_low_and_above"
case blockMediumAndAbove = "block_medium_and_above"
case blockOnlyHigh = "block_only_high"
case blockNone = "block_none"
}

public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove)
public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove)
public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh)
public static let blockNone = SafetyFilterLevel(kind: .blockNone)

let rawValue: String
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct PersonGeneration: ProtoEnum {
enum Kind: String {
case blockAll = "dont_allow"
case allowAdult = "allow_adult"
case allowAll = "allow_all"
}

public static let blockAll = PersonGeneration(kind: .blockAll)
public static let allowAdult = PersonGeneration(kind: .allowAdult)
public static let allowAll = PersonGeneration(kind: .allowAll)

let rawValue: String
self.personFilterLevel = personFilterLevel
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ final class IntegrationTests: XCTestCase {
modelName: "imagen-3.0-fast-generate-001",
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personGeneration: .blockAll
personFilterLevel: .blockAll
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,22 @@ final class ImageGenerationParametersTests: XCTestCase {
}

func testDefaultParameters_includeSafetySettings() throws {
let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockOnlyHigh
let personGeneration = ImagenSafetySettings.PersonGeneration.allowAll
let includeFilterReason = false
let safetyFilterLevel = ImagenSafetyFilterLevel.blockOnlyHigh
let personFilterLevel = ImagenPersonFilterLevel.allowAll
let safetySettings = ImagenSafetySettings(
safetyFilterLevel: safetyFilterLevel,
includeFilterReason: includeFilterReason,
personGeneration: personGeneration
personFilterLevel: personFilterLevel
)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: safetyFilterLevel.rawValue,
personGeneration: personGeneration.rawValue,
personGeneration: personFilterLevel.rawValue,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: includeFilterReason
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
Expand Down Expand Up @@ -156,27 +154,25 @@ final class ImageGenerationParametersTests: XCTestCase {
imageFormat: imageFormat,
addWatermark: addWatermark
)
let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockNone
let personGeneration = ImagenSafetySettings.PersonGeneration.blockAll
let includeFilterReason = false
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
let personFilterLevel = ImagenPersonFilterLevel.blockAll
let safetySettings = ImagenSafetySettings(
safetyFilterLevel: safetyFilterLevel,
includeFilterReason: includeFilterReason,
personGeneration: personGeneration
personFilterLevel: personFilterLevel
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: storageURI,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: safetyFilterLevel.rawValue,
personGeneration: personGeneration.rawValue,
personGeneration: personFilterLevel.rawValue,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: includeFilterReason
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
Expand Down

0 comments on commit 53d43d8

Please sign in to comment.