Skip to content

Commit

Permalink
feat(protocol-kit): Predict Safe address improvements (#982)
Browse files Browse the repository at this point in the history
  • Loading branch information
yagopv authored Sep 26, 2024
1 parent 2f60640 commit dc602ed
Show file tree
Hide file tree
Showing 52 changed files with 850 additions and 196 deletions.
27 changes: 20 additions & 7 deletions packages/protocol-kit/src/Safe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {
getPredictedSafeAddressInitCode,
predictSafeAddress
} from './contracts/utils'
import { DEFAULT_SAFE_VERSION } from './contracts/config'
import { ContractInfo, DEFAULT_SAFE_VERSION, getContractInfo } from './contracts/config'
import ContractManager from './managers/contractManager'
import FallbackHandlerManager from './managers/fallbackHandlerManager'
import GuardManager from './managers/guardManager'
Expand Down Expand Up @@ -124,12 +124,12 @@ class Safe {
async #initializeProtocolKit(config: SafeConfig) {
const { provider, signer, isL1SafeSingleton, contractNetworks } = config

this.#safeProvider = await SafeProvider.init(
this.#safeProvider = await SafeProvider.init({
provider,
signer,
DEFAULT_SAFE_VERSION,
safeVersion: DEFAULT_SAFE_VERSION,
contractNetworks
)
})

if (isSafeConfigWithPredictedSafe(config)) {
this.#predictedSafe = config.predictedSafe
Expand All @@ -155,7 +155,12 @@ class Safe {
}

const safeVersion = this.getContractVersion()
this.#safeProvider = await SafeProvider.init(provider, signer, safeVersion, contractNetworks)
this.#safeProvider = await SafeProvider.init({
provider,
signer,
safeVersion,
contractNetworks
})

this.#ownerManager = new OwnerManager(this.#safeProvider, this.#contractManager.safeContract)
this.#moduleManager = new ModuleManager(this.#safeProvider, this.#contractManager.safeContract)
Expand All @@ -169,14 +174,14 @@ class Safe {
if (isPasskeySigner) {
const safeAddress = await this.getAddress()
const owners = await this.getOwners()
this.#safeProvider = await SafeProvider.init(
this.#safeProvider = await SafeProvider.init({
provider,
signer,
safeVersion,
contractNetworks,
safeAddress,
owners
)
})
}
}

Expand Down Expand Up @@ -1636,6 +1641,14 @@ class Safe {
return false
}
}

getContractInfo = ({
contractAddress
}: {
contractAddress: string
}): ContractInfo | undefined => {
return getContractInfo(contractAddress)
}
}

export default Safe
7 changes: 6 additions & 1 deletion packages/protocol-kit/src/SafeFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ class SafeFactory {
}: SafeFactoryInitConfig) {
this.#provider = provider
this.#signer = signer
this.#safeProvider = await SafeProvider.init(provider, signer, safeVersion, contractNetworks)
this.#safeProvider = await SafeProvider.init({
provider,
signer,
safeVersion,
contractNetworks
})
this.#safeVersion = safeVersion
this.#isL1SafeSingleton = isL1SafeSingleton
this.#contractNetworks = contractNetworks
Expand Down
22 changes: 10 additions & 12 deletions packages/protocol-kit/src/SafeProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@ import {
EIP712TypedDataMessage,
EIP712TypedDataTx,
Eip3770Address,
SafeEIP712Args,
SafeVersion
SafeEIP712Args
} from '@safe-global/types-kit'
import {
SafeProviderTransaction,
SafeProviderConfig,
SafeProviderInitOptions,
ExternalClient,
ExternalSigner,
Eip1193Provider,
HttpTransport,
SocketTransport,
SafeSigner,
SafeConfig,
ContractNetworksConfig,
PasskeyArgType,
PasskeyClient
} from '@safe-global/protocol-kit/types'
Expand Down Expand Up @@ -102,14 +100,14 @@ class SafeProvider {
return this.#externalProvider
}

static async init(
provider: SafeConfig['provider'],
signer?: SafeConfig['signer'],
safeVersion: SafeVersion = DEFAULT_SAFE_VERSION,
contractNetworks?: ContractNetworksConfig,
safeAddress?: string,
owners?: string[]
): Promise<SafeProvider> {
static async init({
provider,
signer,
safeVersion = DEFAULT_SAFE_VERSION,
contractNetworks,
safeAddress,
owners
}: SafeProviderInitOptions): Promise<SafeProvider> {
const isPasskeySigner = signer && typeof signer !== 'string'

if (isPasskeySigner) {
Expand Down
47 changes: 39 additions & 8 deletions packages/protocol-kit/src/contracts/BaseContract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {
Chain
} from 'viem'
import { estimateContractGas, getTransactionReceipt } from 'viem/actions'
import { SingletonDeployment } from '@safe-global/safe-deployments'
import { contractName, getContractDeployment } from '@safe-global/protocol-kit/contracts/config'
import { DeploymentType } from '@safe-global/protocol-kit/types'
import SafeProvider from '@safe-global/protocol-kit/SafeProvider'
import {
EncodeFunction,
Expand Down Expand Up @@ -62,6 +64,7 @@ class BaseContract<ContractAbiType extends Abi> {
* @param safeVersion - The version of the Safe contract.
* @param customContractAddress - Optional custom address for the contract. If not provided, the address is derived from the Safe deployments based on the chainId and safeVersion.
* @param customContractAbi - Optional custom ABI for the contract. If not provided, the ABI is derived from the Safe deployments or the defaultAbi is used.
* @param deploymentType - Optional deployment type for the contract. If not provided, the first deployment retrieved from the safe-deployments array will be used.
*/
constructor(
contractName: contractName,
Expand All @@ -70,24 +73,27 @@ class BaseContract<ContractAbiType extends Abi> {
defaultAbi: ContractAbiType,
safeVersion: SafeVersion,
customContractAddress?: string,
customContractAbi?: ContractAbiType
customContractAbi?: ContractAbiType,
deploymentType?: DeploymentType
) {
const deployment = getContractDeployment(safeVersion, chainId, contractName)

const contractAddress =
customContractAddress || deployment?.networkAddresses[chainId.toString()]
const resolvedAddress =
customContractAddress ??
this.#resolveAddress(
deployment?.networkAddresses[chainId.toString()],
deployment,
deploymentType
)

if (!contractAddress) {
if (!resolvedAddress) {
throw new Error(`Invalid ${contractName.replace('Version', '')} contract address`)
}

this.chainId = chainId
this.contractName = contractName
this.safeVersion = safeVersion
this.contractAddress =
Array.isArray(contractAddress) && contractAddress.length
? contractAddress[0]
: contractAddress.toString()
this.contractAddress = resolvedAddress
this.contractAbi =
customContractAbi ||
(deployment?.abi as unknown as ContractAbiType) || // this cast is required because abi is set as any[] in safe-deployments
Expand All @@ -97,6 +103,31 @@ class BaseContract<ContractAbiType extends Abi> {
this.safeProvider = safeProvider
}

#resolveAddress(
networkAddresses: string | string[] | undefined,
deployment: SingletonDeployment,
deploymentType?: DeploymentType
): string | undefined {
if (!networkAddresses) {
return undefined
}

if (typeof networkAddresses === 'string') {
return networkAddresses
}

if (deploymentType) {
const customDeploymentTypeAddress = deployment.deployments[deploymentType]?.address

return (
networkAddresses.find((address) => address === customDeploymentTypeAddress) ??
networkAddresses[0]
)
}

return networkAddresses[0]
}

async init() {
this.wallet = await this.safeProvider.getExternalSigner()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Abi } from 'abitype'

import SafeProvider from '@safe-global/protocol-kit/SafeProvider'
import BaseContract from '@safe-global/protocol-kit/contracts/BaseContract'
import { DeploymentType } from '@safe-global/protocol-kit/types'
import { SafeVersion } from '@safe-global/types-kit'
import { contractName } from '@safe-global/protocol-kit/contracts/config'

Expand Down Expand Up @@ -33,14 +34,16 @@ abstract class CompatibilityFallbackHandlerBaseContract<
* @param safeVersion - The version of the Safe contract.
* @param customContractAddress - Optional custom address for the contract. If not provided, the address is derived from the Safe deployments based on the chainId and safeVersion.
* @param customContractAbi - Optional custom ABI for the contract. If not provided, the ABI is derived from the Safe deployments or the defaultAbi is used.
* @param deploymentType - Optional deployment type for the contract. If not provided, the first deployment retrieved from the safe-deployments array will be used.
*/
constructor(
chainId: bigint,
safeProvider: SafeProvider,
defaultAbi: CompatibilityFallbackHandlerContractAbiType,
safeVersion: SafeVersion,
customContractAddress?: string,
customContractAbi?: CompatibilityFallbackHandlerContractAbiType
customContractAbi?: CompatibilityFallbackHandlerContractAbiType,
deploymentType?: DeploymentType
) {
const contractName = 'compatibilityFallbackHandler'

Expand All @@ -51,7 +54,8 @@ abstract class CompatibilityFallbackHandlerBaseContract<
defaultAbi,
safeVersion,
customContractAddress,
customContractAbi
customContractAbi,
deploymentType
)

this.contractName = contractName
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import CompatibilityFallbackHandlerBaseContract from '@safe-global/protocol-kit/contracts/CompatibilityFallbackHandler/CompatibilityFallbackHandlerBaseContract'
import SafeProvider from '@safe-global/protocol-kit/SafeProvider'
import { DeploymentType } from '@safe-global/protocol-kit/types'
import {
CompatibilityFallbackHandlerContract_v1_3_0_Abi,
CompatibilityFallbackHandlerContract_v1_3_0_Contract,
Expand All @@ -25,17 +26,27 @@ class CompatibilityFallbackHandlerContract_v1_3_0
* @param safeProvider - An instance of SafeProvider.
* @param customContractAddress - Optional custom address for the contract. If not provided, the address is derived from the CompatibilityFallbackHandler deployments based on the chainId and safeVersion.
* @param customContractAbi - Optional custom ABI for the contract. If not provided, the default ABI for version 1.3.0 is used.
* @param deploymentType - Optional deployment type for the contract. If not provided, the first deployment retrieved from the safe-deployments array will be used.
*/
constructor(
chainId: bigint,
safeProvider: SafeProvider,
customContractAddress?: string,
customContractAbi?: CompatibilityFallbackHandlerContract_v1_3_0_Abi
customContractAbi?: CompatibilityFallbackHandlerContract_v1_3_0_Abi,
deploymentType?: DeploymentType
) {
const safeVersion = '1.3.0'
const defaultAbi = compatibilityFallbackHandler_1_3_0_ContractArtifacts.abi

super(chainId, safeProvider, defaultAbi, safeVersion, customContractAddress, customContractAbi)
super(
chainId,
safeProvider,
defaultAbi,
safeVersion,
customContractAddress,
customContractAbi,
deploymentType
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import CompatibilityFallbackHandlerBaseContract from '@safe-global/protocol-kit/contracts/CompatibilityFallbackHandler/CompatibilityFallbackHandlerBaseContract'
import SafeProvider from '@safe-global/protocol-kit/SafeProvider'
import { DeploymentType } from '@safe-global/protocol-kit/types'
import {
compatibilityFallbackHandler_1_4_1_ContractArtifacts,
CompatibilityFallbackHandlerContract_v1_4_1_Abi,
Expand All @@ -25,17 +26,27 @@ class CompatibilityFallbackHandlerContract_v1_4_1
* @param safeProvider - An instance of SafeProvider.
* @param customContractAddress - Optional custom address for the contract. If not provided, the address is derived from the CompatibilityFallbackHandler deployments based on the chainId and safeVersion.
* @param customContractAbi - Optional custom ABI for the contract. If not provided, the default ABI for version 1.4.1 is used.
* @param deploymentType - Optional deployment type for the contract. If not provided, the first deployment retrieved from the safe-deployments array will be used.
*/
constructor(
chainId: bigint,
safeProvider: SafeProvider,
customContractAddress?: string,
customContractAbi?: CompatibilityFallbackHandlerContract_v1_4_1_Abi
customContractAbi?: CompatibilityFallbackHandlerContract_v1_4_1_Abi,
deploymentType?: DeploymentType
) {
const safeVersion = '1.4.1'
const defaultAbi = compatibilityFallbackHandler_1_4_1_ContractArtifacts.abi

super(chainId, safeProvider, defaultAbi, safeVersion, customContractAddress, customContractAbi)
super(
chainId,
safeProvider,
defaultAbi,
safeVersion,
customContractAddress,
customContractAbi,
deploymentType
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Abi } from 'abitype'

import SafeProvider from '@safe-global/protocol-kit/SafeProvider'
import { DeploymentType } from '@safe-global/protocol-kit/types'
import BaseContract from '@safe-global/protocol-kit/contracts/BaseContract'
import { SafeVersion } from '@safe-global/types-kit'
import { contractName } from '@safe-global/protocol-kit/contracts/config'
Expand Down Expand Up @@ -33,14 +34,16 @@ abstract class CreateCallBaseContract<
* @param safeVersion - The version of the Safe contract.
* @param customContractAddress - Optional custom address for the contract. If not provided, the address is derived from the Safe deployments based on the chainId and safeVersion.
* @param customContractAbi - Optional custom ABI for the contract. If not provided, the ABI is derived from the Safe deployments or the defaultAbi is used.
* @param deploymentType - Optional deployment type for the contract. If not provided, the first deployment retrieved from the safe-deployments array will be used.
*/
constructor(
chainId: bigint,
safeProvider: SafeProvider,
defaultAbi: CreateCallContractAbiType,
safeVersion: SafeVersion,
customContractAddress?: string,
customContractAbi?: CreateCallContractAbiType
customContractAbi?: CreateCallContractAbiType,
deploymentType?: DeploymentType
) {
const contractName = 'createCallVersion'

Expand All @@ -51,7 +54,8 @@ abstract class CreateCallBaseContract<
defaultAbi,
safeVersion,
customContractAddress,
customContractAbi
customContractAbi,
deploymentType
)

this.contractName = contractName
Expand Down
Loading

0 comments on commit dc602ed

Please sign in to comment.