Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for deployments.get endpoint #74

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Sources/Replicate/Account.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import struct Foundation.URL

/// A Replicate account.
public struct Account: Hashable, Codable {
public struct Account: Hashable {
/// The acount type.
public enum AccountType: String, CaseIterable, Hashable, Codable {
/// A user.
Expand Down Expand Up @@ -42,6 +42,17 @@ extension Account: CustomStringConvertible {
}
}

// MARK: - Codable

extension Account: Codable {
public enum CodingKeys: String, CodingKey {
case type
case username
case name
case githubURL = "github_url"
}
}

extension Account.AccountType: CustomStringConvertible {
public var description: String {
return self.rawValue
Expand Down
15 changes: 15 additions & 0 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,21 @@ public class Client {

// MARK: -

/// Get a deployment
///
/// - Parameters:
/// - id: The deployment identifier, comprising
/// the name of the user or organization that owns the deployment and
/// the name of the deployment.
/// For example, "replicate/my-app-image-generator".
public func getDeployment(_ id: Deployment.ID)
async throws -> Deployment
{
return try await fetch(.get, "deployments/\(id)")
}

// MARK: -

private enum Method: String, Hashable {
case get = "GET"
case post = "POST"
Expand Down
107 changes: 79 additions & 28 deletions Sources/Replicate/Deployment.swift
Original file line number Diff line number Diff line change
@@ -1,44 +1,95 @@
import struct Foundation.Date

/// A deployment of a model on Replicate.
public enum Deployment {
/// A deployment identifier.
public struct ID: Hashable {
/// The owner of the deployment.
public let owner: String

/// The name of the deployment.
public let name: String
public struct Deployment: Hashable {
/// The owner of the deployment.
public let owner: String

/// The name of the deployment.
public let name: String

/// A release of a deployment.
public struct Release: Hashable {
/// The release number.
let number: Int

/// The model.
let model: Model.ID

/// The model version.
let version: Model.Version.ID

/// The time at which the release was created.
let createdAt: Date

/// The account that created the release
let createdBy: Account

/// The configuration of a deployment.
public struct Configuration: Hashable {
/// The configured hardware SKU.
public let hardware: Hardware.ID

/// A scaling configuration for a deployment.
public struct Scaling: Hashable {
/// The maximum number of instances.
public let maxInstances: Int

/// The minimum number of instances.
public let minInstances: Int
}

/// The scaling configuration for the deployment.
public let scaling: Scaling
}

/// The deployment configuration.
public let configuration: Configuration
}

public let currentRelease: Release?
}

// MARK: - CustomStringConvertible
// MARK: - Identifiable

extension Deployment.ID: CustomStringConvertible {
public var description: String {
return "\(owner)/\(name)"
}
extension Deployment: Identifiable {
public typealias ID = String

/// The ID of the model.
public var id: ID { "\(owner)/\(name)" }
}

// MARK: - ExpressibleByStringLiteral
// MARK: - Codable

extension Deployment.ID: ExpressibleByStringLiteral {
public init(stringLiteral value: StringLiteralType) {
let components = value.split(separator: "/")
guard components.count == 2 else { fatalError("Invalid deployment ID: \(value)") }
self.init(owner: String(components[0]), name: String(components[1]))
extension Deployment: Codable {
public enum CodingKeys: String, CodingKey {
case owner
case name
case currentRelease = "current_release"
}
}

// MARK: - Codable
extension Deployment.Release: Codable {
public enum CodingKeys: String, CodingKey {
case number
case model
case version
case createdAt = "created_at"
case createdBy = "created_by"
case configuration
}
}

extension Deployment.ID: Codable {
public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
let value = try container.decode(String.self)
self.init(stringLiteral: value)
extension Deployment.Release.Configuration: Codable {
public enum CodingKeys: String, CodingKey {
case hardware
case scaling
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
try container.encode(description)
extension Deployment.Release.Configuration.Scaling: Codable {
public enum CodingKeys: String, CodingKey {
case minInstances = "min_instances"
case maxInstances = "max_instances"
}
}
18 changes: 18 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,24 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(account.type, .organization)
XCTAssertEqual(account.username, "replicate")
XCTAssertEqual(account.name, "Replicate")
XCTAssertEqual(account.githubURL?.absoluteString, "https://github.com/replicate")
}

func testGetDeployment() async throws {
let deployment = try await client.getDeployment("replicate/my-app-image-generator")
XCTAssertEqual(deployment.owner, "replicate")
XCTAssertEqual(deployment.name, "my-app-image-generator")
XCTAssertEqual(deployment.currentRelease?.number, 1)
XCTAssertEqual(deployment.currentRelease?.model, "stability-ai/sdxl")
XCTAssertEqual(deployment.currentRelease?.version, "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf")
XCTAssertEqual(deployment.currentRelease!.createdAt.timeIntervalSinceReferenceDate, 729707577.01, accuracy: 1)
XCTAssertEqual(deployment.currentRelease?.createdBy.type, .organization)
XCTAssertEqual(deployment.currentRelease?.createdBy.username, "replicate")
XCTAssertEqual(deployment.currentRelease?.createdBy.name, "Replicate, Inc.")
XCTAssertEqual(deployment.currentRelease?.createdBy.githubURL?.absoluteString, "https://github.com/replicate")
XCTAssertEqual(deployment.currentRelease?.configuration.hardware, "gpu-t4")
XCTAssertEqual(deployment.currentRelease?.configuration.scaling.minInstances, 1)
XCTAssertEqual(deployment.currentRelease?.configuration.scaling.maxInstances, 5)
}

func testCustomBaseURL() async throws {
Expand Down
27 changes: 27 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,33 @@ class MockURLProtocol: URLProtocol {
"metrics": {}
}
"""#
case ("GET", "https://api.replicate.com/v1/deployments/replicate/my-app-image-generator"?):
statusCode = 200
json = #"""
{
"owner": "replicate",
"name": "my-app-image-generator",
"current_release": {
"number": 1,
"model": "stability-ai/sdxl",
"version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
"created_at": "2024-02-15T16:32:57.018467Z",
"created_by": {
"type": "organization",
"username": "replicate",
"name": "Replicate, Inc.",
"github_url": "https://github.com/replicate",
},
"configuration": {
"hardware": "gpu-t4",
"scaling": {
"min_instances": 1,
"max_instances": 5
}
}
}
}
"""#
default:
client?.urlProtocol(self, didFailWithError: URLError(.badURL))
return
Expand Down
Loading