diff --git a/Sources/XMTP/Client.swift b/Sources/XMTP/Client.swift index ebcd0cd7..14fc7f37 100644 --- a/Sources/XMTP/Client.swift +++ b/Sources/XMTP/Client.swift @@ -9,6 +9,8 @@ import Foundation import web3 import XMTPRust +public typealias PreEventCallback = () async throws -> Void + public enum ClientError: Error { case creationError(String) } @@ -35,10 +37,18 @@ public struct ClientOptions { public var api = Api() public var codecs: [any ContentCodec] = [] - - public init(api: Api = Api(), codecs: [any ContentCodec] = []) { + + /// `preEnableIdentityCallback` will be called immediately before an Enable Identity wallet signature is requested from the user. + public var preEnableIdentityCallback: PreEventCallback? = nil + + /// `preCreateIdentityCallback` will be called immediately before a Create Identity wallet signature is requested from the user. + public var preCreateIdentityCallback: PreEventCallback? = nil + + public init(api: Api = Api(), codecs: [any ContentCodec] = [], preEnableIdentityCallback: PreEventCallback? = nil, preCreateIdentityCallback: PreEventCallback? = nil) { self.api = api self.codecs = codecs + self.preEnableIdentityCallback = preEnableIdentityCallback + self.preCreateIdentityCallback = preCreateIdentityCallback } } @@ -83,14 +93,14 @@ public final class Client: Sendable { secure: options.api.isSecure, rustClient: client ) - return try await create(account: account, apiClient: apiClient) + return try await create(account: account, apiClient: apiClient, options: options) } catch let error as RustString { throw ClientError.creationError(error.toString()) } } - static func create(account: SigningKey, apiClient: ApiClient) async throws -> Client { - let privateKeyBundleV1 = try await loadOrCreateKeys(for: account, apiClient: apiClient) + static func create(account: SigningKey, apiClient: ApiClient, options: ClientOptions? = nil) async throws -> Client { + let privateKeyBundleV1 = try await loadOrCreateKeys(for: account, apiClient: apiClient, options: options) let client = try Client(address: account.address, privateKeyBundleV1: privateKeyBundleV1, apiClient: apiClient) try await client.ensureUserContactPublished() @@ -98,9 +108,9 @@ public final class Client: Sendable { return client } - static func loadOrCreateKeys(for account: SigningKey, apiClient: ApiClient) async throws -> PrivateKeyBundleV1 { + static func loadOrCreateKeys(for account: SigningKey, apiClient: ApiClient, options: ClientOptions? = nil) async throws -> PrivateKeyBundleV1 { // swiftlint:disable no_optional_try - if let keys = try await loadPrivateKeys(for: account, apiClient: apiClient) { + if let keys = try await loadPrivateKeys(for: account, apiClient: apiClient, options: options) { // swiftlint:enable no_optional_try print("loading existing private keys.") #if DEBUG @@ -111,11 +121,9 @@ public final class Client: Sendable { #if DEBUG print("No existing keys found, creating new bundle.") #endif - - let keys = try await PrivateKeyBundleV1.generate(wallet: account) + let keys = try await PrivateKeyBundleV1.generate(wallet: account, options: options) let keyBundle = PrivateKeyBundle(v1: keys) - let encryptedKeys = try await keyBundle.encrypted(with: account) - + let encryptedKeys = try await keyBundle.encrypted(with: account, preEnableIdentityCallback: options?.preEnableIdentityCallback) var authorizedIdentity = AuthorizedIdentity(privateKeyBundleV1: keys) authorizedIdentity.address = account.address let authToken = try await authorizedIdentity.createAuthToken() @@ -129,7 +137,7 @@ public final class Client: Sendable { } } - static func loadPrivateKeys(for account: SigningKey, apiClient: ApiClient) async throws -> PrivateKeyBundleV1? { + static func loadPrivateKeys(for account: SigningKey, apiClient: ApiClient, options: ClientOptions? = nil) async throws -> PrivateKeyBundleV1? { let res = try await apiClient.query( topic: .userPrivateStoreKeyBundle(account.address), pagination: nil @@ -137,7 +145,7 @@ public final class Client: Sendable { for envelope in res.envelopes { let encryptedBundle = try EncryptedPrivateKeyBundle(serializedData: envelope.message) - let bundle = try await encryptedBundle.decrypted(with: account) + let bundle = try await encryptedBundle.decrypted(with: account, preEnableIdentityCallback: options?.preEnableIdentityCallback ) if case .v1 = bundle.version { return bundle.v1 } diff --git a/Sources/XMTP/Messages/EncryptedPrivateKeyBundle.swift b/Sources/XMTP/Messages/EncryptedPrivateKeyBundle.swift index 54ab6a5a..1684d31d 100644 --- a/Sources/XMTP/Messages/EncryptedPrivateKeyBundle.swift +++ b/Sources/XMTP/Messages/EncryptedPrivateKeyBundle.swift @@ -8,7 +8,8 @@ typealias EncryptedPrivateKeyBundle = Xmtp_MessageContents_EncryptedPrivateKeyBundle extension EncryptedPrivateKeyBundle { - func decrypted(with key: SigningKey) async throws -> PrivateKeyBundle { + func decrypted(with key: SigningKey, preEnableIdentityCallback: PreEventCallback? = nil) async throws -> PrivateKeyBundle { + try await preEnableIdentityCallback?() let signature = try await key.sign(message: Signature.enableIdentityText(key: v1.walletPreKey)) let message = try Crypto.decrypt(signature.rawDataWithNormalizedRecovery, v1.ciphertext) diff --git a/Sources/XMTP/Messages/PrivateKeyBundle.swift b/Sources/XMTP/Messages/PrivateKeyBundle.swift index 96331dd9..35f36d54 100644 --- a/Sources/XMTP/Messages/PrivateKeyBundle.swift +++ b/Sources/XMTP/Messages/PrivateKeyBundle.swift @@ -20,10 +20,12 @@ extension PrivateKeyBundle { self.v1 = v1 } - func encrypted(with key: SigningKey) async throws -> EncryptedPrivateKeyBundle { + func encrypted(with key: SigningKey, preEnableIdentityCallback: PreEventCallback? = nil) async throws -> EncryptedPrivateKeyBundle { let bundleBytes = try serializedData() let walletPreKey = try Crypto.secureRandomBytes(count: 32) - + + try await preEnableIdentityCallback?() + let signature = try await key.sign(message: Signature.enableIdentityText(key: walletPreKey)) let cipherText = try Crypto.encrypt(signature.rawDataWithNormalizedRecovery, bundleBytes) diff --git a/Sources/XMTP/Messages/PrivateKeyBundleV1.swift b/Sources/XMTP/Messages/PrivateKeyBundleV1.swift index 5161f1d3..504c1ed0 100644 --- a/Sources/XMTP/Messages/PrivateKeyBundleV1.swift +++ b/Sources/XMTP/Messages/PrivateKeyBundleV1.swift @@ -13,9 +13,9 @@ import XMTPRust public typealias PrivateKeyBundleV1 = Xmtp_MessageContents_PrivateKeyBundleV1 extension PrivateKeyBundleV1 { - static func generate(wallet: SigningKey) async throws -> PrivateKeyBundleV1 { + static func generate(wallet: SigningKey, options: ClientOptions? = nil) async throws -> PrivateKeyBundleV1 { let privateKey = try PrivateKey.generate() - let authorizedIdentity = try await wallet.createIdentity(privateKey) + let authorizedIdentity = try await wallet.createIdentity(privateKey, preCreateIdentityCallback: options?.preCreateIdentityCallback) var bundle = try authorizedIdentity.toBundle var preKey = try PrivateKey.generate() diff --git a/Sources/XMTP/SigningKey.swift b/Sources/XMTP/SigningKey.swift index 7b9131a8..4f2b608d 100644 --- a/Sources/XMTP/SigningKey.swift +++ b/Sources/XMTP/SigningKey.swift @@ -29,11 +29,13 @@ public protocol SigningKey { } extension SigningKey { - func createIdentity(_ identity: PrivateKey) async throws -> AuthorizedIdentity { + func createIdentity(_ identity: PrivateKey, preCreateIdentityCallback: PreEventCallback? = nil) async throws -> AuthorizedIdentity { var slimKey = PublicKey() slimKey.timestamp = UInt64(Date().millisecondsSinceEpoch) slimKey.secp256K1Uncompressed = identity.publicKey.secp256K1Uncompressed + try await preCreateIdentityCallback?() + let signatureText = Signature.createIdentityText(key: try slimKey.serializedData()) let signature = try await sign(message: signatureText) diff --git a/Tests/XMTPTests/ClientTests.swift b/Tests/XMTPTests/ClientTests.swift index 7a156c82..7a2e7e7d 100644 --- a/Tests/XMTPTests/ClientTests.swift +++ b/Tests/XMTPTests/ClientTests.swift @@ -100,4 +100,40 @@ class ClientTests: XCTestCase { XCTAssertEqual(recovered, client.keys.identityKey.publicKey.secp256K1Uncompressed.bytes) } + + func testPreEnableIdentityCallback() async throws { + let fakeWallet = try PrivateKey.generate() + let expectation = XCTestExpectation(description: "preEnableIdentityCallback is called") + + let preEnableIdentityCallback: () async throws -> Void = { + print("preEnableIdentityCallback called") + expectation.fulfill() + } + + let opts = ClientOptions(api: ClientOptions.Api(env: .local, isSecure: false), preEnableIdentityCallback: preEnableIdentityCallback ) + do { + _ = try await Client.create(account: fakeWallet, options: opts) + await XCTWaiter().fulfillment(of: [expectation], timeout: 5) + } catch { + XCTFail("Error: \(error)") + } + } + + func testPreCreateIdentityCallback() async throws { + let fakeWallet = try PrivateKey.generate() + let expectation = XCTestExpectation(description: "preCreateIdentityCallback is called") + + let preCreateIdentityCallback: () async throws -> Void = { + print("preCreateIdentityCallback called") + expectation.fulfill() + } + + let opts = ClientOptions(api: ClientOptions.Api(env: .local, isSecure: false), preCreateIdentityCallback: preCreateIdentityCallback ) + do { + _ = try await Client.create(account: fakeWallet, options: opts) + await XCTWaiter().fulfillment(of: [expectation], timeout: 5) + } catch { + XCTFail("Error: \(error)") + } + } } diff --git a/XMTP.podspec b/XMTP.podspec index 795c6ec5..2bd5d041 100644 --- a/XMTP.podspec +++ b/XMTP.podspec @@ -16,7 +16,7 @@ Pod::Spec.new do |spec| # spec.name = "XMTP" - spec.version = "0.7.2-alpha0" + spec.version = "0.7.3-alpha0" spec.summary = "XMTP SDK Cocoapod" # This description is used to generate tags and improve search results.