Skip to content

Commit

Permalink
Add support for deployments.get endpoint (#74)
Browse files Browse the repository at this point in the history
* Fix Codable implementation for Account

* Add support for deployments.get endpoint
  • Loading branch information
mattt authored Feb 19, 2024
1 parent 999217f commit 84a1b75
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 29 deletions.
13 changes: 12 additions & 1 deletion Sources/Replicate/Account.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import struct Foundation.URL

/// A Replicate account.
public struct Account: Hashable, Codable {
public struct Account: Hashable {
/// The acount type.
public enum AccountType: String, CaseIterable, Hashable, Codable {
/// A user.
Expand Down Expand Up @@ -42,6 +42,17 @@ extension Account: CustomStringConvertible {
}
}

// MARK: - Codable

extension Account: Codable {
public enum CodingKeys: String, CodingKey {
case type
case username
case name
case githubURL = "github_url"
}
}

extension Account.AccountType: CustomStringConvertible {
public var description: String {
return self.rawValue
Expand Down
15 changes: 15 additions & 0 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,21 @@ public class Client {

// MARK: -

/// Get a deployment
///
/// - Parameters:
/// - id: The deployment identifier, comprising
/// the name of the user or organization that owns the deployment and
/// the name of the deployment.
/// For example, "replicate/my-app-image-generator".
public func getDeployment(_ id: Deployment.ID)
async throws -> Deployment
{
return try await fetch(.get, "deployments/\(id)")
}

// MARK: -

private enum Method: String, Hashable {
case get = "GET"
case post = "POST"
Expand Down
107 changes: 79 additions & 28 deletions Sources/Replicate/Deployment.swift
Original file line number Diff line number Diff line change
@@ -1,44 +1,95 @@
import struct Foundation.Date

/// A deployment of a model on Replicate.
public enum Deployment {
/// A deployment identifier.
public struct ID: Hashable {
/// The owner of the deployment.
public let owner: String

/// The name of the deployment.
public let name: String
public struct Deployment: Hashable {
/// The owner of the deployment.
public let owner: String

/// The name of the deployment.
public let name: String

/// A release of a deployment.
public struct Release: Hashable {
/// The release number.
let number: Int

/// The model.
let model: Model.ID

/// The model version.
let version: Model.Version.ID

/// The time at which the release was created.
let createdAt: Date

/// The account that created the release
let createdBy: Account

/// The configuration of a deployment.
public struct Configuration: Hashable {
/// The configured hardware SKU.
public let hardware: Hardware.ID

/// A scaling configuration for a deployment.
public struct Scaling: Hashable {
/// The maximum number of instances.
public let maxInstances: Int

/// The minimum number of instances.
public let minInstances: Int
}

/// The scaling configuration for the deployment.
public let scaling: Scaling
}

/// The deployment configuration.
public let configuration: Configuration
}

public let currentRelease: Release?
}

// MARK: - CustomStringConvertible
// MARK: - Identifiable

extension Deployment.ID: CustomStringConvertible {
public var description: String {
return "\(owner)/\(name)"
}
extension Deployment: Identifiable {
public typealias ID = String

/// The ID of the model.
public var id: ID { "\(owner)/\(name)" }
}

// MARK: - ExpressibleByStringLiteral
// MARK: - Codable

extension Deployment.ID: ExpressibleByStringLiteral {
public init(stringLiteral value: StringLiteralType) {
let components = value.split(separator: "/")
guard components.count == 2 else { fatalError("Invalid deployment ID: \(value)") }
self.init(owner: String(components[0]), name: String(components[1]))
extension Deployment: Codable {
public enum CodingKeys: String, CodingKey {
case owner
case name
case currentRelease = "current_release"
}
}

// MARK: - Codable
extension Deployment.Release: Codable {
public enum CodingKeys: String, CodingKey {
case number
case model
case version
case createdAt = "created_at"
case createdBy = "created_by"
case configuration
}
}

extension Deployment.ID: Codable {
public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
let value = try container.decode(String.self)
self.init(stringLiteral: value)
extension Deployment.Release.Configuration: Codable {
public enum CodingKeys: String, CodingKey {
case hardware
case scaling
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
try container.encode(description)
extension Deployment.Release.Configuration.Scaling: Codable {
public enum CodingKeys: String, CodingKey {
case minInstances = "min_instances"
case maxInstances = "max_instances"
}
}
18 changes: 18 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,24 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(account.type, .organization)
XCTAssertEqual(account.username, "replicate")
XCTAssertEqual(account.name, "Replicate")
XCTAssertEqual(account.githubURL?.absoluteString, "https://github.com/replicate")
}

func testGetDeployment() async throws {
let deployment = try await client.getDeployment("replicate/my-app-image-generator")
XCTAssertEqual(deployment.owner, "replicate")
XCTAssertEqual(deployment.name, "my-app-image-generator")
XCTAssertEqual(deployment.currentRelease?.number, 1)
XCTAssertEqual(deployment.currentRelease?.model, "stability-ai/sdxl")
XCTAssertEqual(deployment.currentRelease?.version, "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf")
XCTAssertEqual(deployment.currentRelease!.createdAt.timeIntervalSinceReferenceDate, 729707577.01, accuracy: 1)
XCTAssertEqual(deployment.currentRelease?.createdBy.type, .organization)
XCTAssertEqual(deployment.currentRelease?.createdBy.username, "replicate")
XCTAssertEqual(deployment.currentRelease?.createdBy.name, "Replicate, Inc.")
XCTAssertEqual(deployment.currentRelease?.createdBy.githubURL?.absoluteString, "https://github.com/replicate")
XCTAssertEqual(deployment.currentRelease?.configuration.hardware, "gpu-t4")
XCTAssertEqual(deployment.currentRelease?.configuration.scaling.minInstances, 1)
XCTAssertEqual(deployment.currentRelease?.configuration.scaling.maxInstances, 5)
}

func testCustomBaseURL() async throws {
Expand Down
27 changes: 27 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,33 @@ class MockURLProtocol: URLProtocol {
"metrics": {}
}
"""#
case ("GET", "https://api.replicate.com/v1/deployments/replicate/my-app-image-generator"?):
statusCode = 200
json = #"""
{
"owner": "replicate",
"name": "my-app-image-generator",
"current_release": {
"number": 1,
"model": "stability-ai/sdxl",
"version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
"created_at": "2024-02-15T16:32:57.018467Z",
"created_by": {
"type": "organization",
"username": "replicate",
"name": "Replicate, Inc.",
"github_url": "https://github.com/replicate",
},
"configuration": {
"hardware": "gpu-t4",
"scaling": {
"min_instances": 1,
"max_instances": 5
}
}
}
}
"""#
default:
client?.urlProtocol(self, didFailWithError: URLError(.badURL))
return
Expand Down

0 comments on commit 84a1b75

Please sign in to comment.