From 301cf42abe3be98a4aa77018c4c5c62d4a8b9ec3 Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Sun, 8 Sep 2024 13:15:04 +0500 Subject: [PATCH] fix(coordinator): cleanup session key service and use plaintext approvals --- .../ts/common/accountAbstraction.ts | 35 ++++++ packages/coordinator/ts/common/errors.ts | 2 + packages/coordinator/ts/common/networks.ts | 110 +++++++++++------- .../__tests__/sessionKeys.service.test.ts | 36 ++---- .../ts/sessionKeys/sessionKeys.module.ts | 4 +- .../ts/sessionKeys/sessionKeys.service.ts | 40 +++---- 6 files changed, 133 insertions(+), 94 deletions(-) diff --git a/packages/coordinator/ts/common/accountAbstraction.ts b/packages/coordinator/ts/common/accountAbstraction.ts index 07c24cec..0621c6e0 100644 --- a/packages/coordinator/ts/common/accountAbstraction.ts +++ b/packages/coordinator/ts/common/accountAbstraction.ts @@ -1,6 +1,8 @@ import dotenv from "dotenv"; +import { type Chain, createPublicClient, http, type HttpTransport, type PublicClient } from "viem"; import { ErrorCodes } from "./errors"; +import { ESupportedNetworks, viemChain } from "./networks"; dotenv.config(); @@ -18,3 +20,36 @@ export const genPimlicoRPCUrl = (network: string): string => { return `https://api.pimlico.io/v2/${network}/rpc?apikey=${pimlicoAPIKey}`; }; + +/** + * Generate the RPCUrl for Alchemy based on the chain we need to interact with + * @param network - the network we want to interact with + * @returns the RPCUrl for the network + */ +export const genAlchemyRPCUrl = (network: ESupportedNetworks): string => { + const rpcAPIKey = process.env.RPC_API_KEY; + + if (!rpcAPIKey) { + throw new Error(ErrorCodes.RPC_API_KEY_NOT_SET); + } + + switch (network) { + case ESupportedNetworks.OPTIMISM_SEPOLIA: + return `https://opt-sepolia.g.alchemy.com/v2/${rpcAPIKey}`; + case ESupportedNetworks.ETHEREUM_SEPOLIA: + return `https://eth-sepolia.g.alchemy.com/v2/${rpcAPIKey}`; + default: + throw new Error(ErrorCodes.UNSUPPORTED_NETWORK); + } +}; + +/** + * Get a public client + * @param rpcUrl - the RPC URL + * @returns the public client + */ +export const getPublicClient = (chainName: ESupportedNetworks): PublicClient => + createPublicClient({ + transport: http(genAlchemyRPCUrl(chainName)), + chain: viemChain(chainName), + }); diff --git a/packages/coordinator/ts/common/errors.ts b/packages/coordinator/ts/common/errors.ts index 293e4aa2..27a5151a 100644 --- a/packages/coordinator/ts/common/errors.ts +++ b/packages/coordinator/ts/common/errors.ts @@ -13,4 +13,6 @@ export enum ErrorCodes { SESSION_KEY_NOT_FOUND = "8", PIMLICO_API_KEY_NOT_SET = "9", INVALID_APPROVAL = "10", + UNSUPPORTED_NETWORK = "11", + RPC_API_KEY_NOT_SET = "12", } diff --git a/packages/coordinator/ts/common/networks.ts b/packages/coordinator/ts/common/networks.ts index a74e922e..65d77512 100644 --- a/packages/coordinator/ts/common/networks.ts +++ b/packages/coordinator/ts/common/networks.ts @@ -1,3 +1,23 @@ +import { + arbitrum, + arbitrumSepolia, + base, + baseSepolia, + bsc, + type Chain, + gnosis, + holesky, + linea, + lineaSepolia, + mainnet, + optimism, + optimismSepolia, + polygon, + scroll, + scrollSepolia, + sepolia, +} from "viem/chains"; + export enum ESupportedNetworks { ETHEREUM = "mainnet", OPTIMISM = "optimism", @@ -5,60 +25,60 @@ export enum ESupportedNetworks { BSC = "bsc", BSC_CHAPEL = "chapel", GNOSIS_CHAIN = "gnosis", - FUSE = "fuse", POLYGON = "matic", - FANTOM_OPERA = "fantom", - ZKSYNC_ERA_TESTNET = "zksync-era-testnet", - BOBA = "boba", - MOONBEAM = "moonbeam", - MOONRIVER = "moonriver", - MOONBASE_ALPHA = "mbase", - FANTOM_TESTNET = "fantom-testnet", ARBITRUM_ONE = "arbitrum-one", - CELO = "celo", - AVALANCHE_FUJI = "fuji", - AVALANCHE = "avalanche", - CELO_ALFAJORES = "celo-alfajores", HOLESKY = "holesky", - AURORA = "aurora", - AURORA_TESTNET = "aurora-testnet", - HARMONY = "harmony", LINEA_SEPOLIA = "linea-sepolia", - GNOSIS_CHIADO = "gnosis-chiado", - MODE_SEPOLIA = "mode-sepolia", - MODE = "mode-mainnet", BASE_SEPOLIA = "base-sepolia", - ZKSYNC_ERA_SEPOLIA = "zksync-era-sepolia", - POLYGON_ZKEVM = "polygon-zkevm", - ZKSYNC_ERA = "zksync-era", ETHEREUM_SEPOLIA = "sepolia", ARBITRUM_SEPOLIA = "arbitrum-sepolia", LINEA = "linea", BASE = "base", SCROLL_SEPOLIA = "scroll-sepolia", SCROLL = "scroll", - BLAST_MAINNET = "blast-mainnet", - ASTAR_ZKEVM_MAINNET = "astar-zkevm-mainnet", - SEI_TESTNET = "sei-testnet", - BLAST_TESTNET = "blast-testnet", - ETHERLINK_TESTNET = "etherlink-testnet", - XLAYER_SEPOLIA = "xlayer-sepolia", - XLAYER_MAINNET = "xlayer-mainnet", - POLYGON_AMOY = "polygon-amoy", - ZKYOTO_TESTNET = "zkyoto-testnet", - POLYGON_ZKEVM_CARDONA = "polygon-zkevm-cardona", - SEI_MAINNET = "sei-mainnet", - ROOTSTOCK_MAINNET = "rootstock", - IOTEX_MAINNET = "iotex", - NEAR_MAINNET = "near-mainnet", - NEAR_TESTNET = "near-testnet", - COSMOS = "cosmoshub-4", - COSMOS_HUB = "theta-testnet-001", - OSMOSIS = "osmosis-1", - OSMO_TESTNET = "osmo-test-4", - ARWEAVE = "arweave-mainnet", - BITCOIN = "btc", - SOLANA = "solana-mainnet-beta", - INJECTIVE_MAINNET = "injective-mainnet", - INJECTIVE_TESTNET = "injective-testnet", } + +/** + * Get the Viem chain for a given network + * + * @param network - the network to get the chain for + * @returns the Viem chain + */ +export const viemChain = (network: ESupportedNetworks): Chain => { + switch (network) { + case ESupportedNetworks.ETHEREUM: + return mainnet; + case ESupportedNetworks.ETHEREUM_SEPOLIA: + return sepolia; + case ESupportedNetworks.ARBITRUM_ONE: + return arbitrum; + case ESupportedNetworks.ARBITRUM_SEPOLIA: + return arbitrumSepolia; + case ESupportedNetworks.BASE_SEPOLIA: + return baseSepolia; + case ESupportedNetworks.LINEA_SEPOLIA: + return lineaSepolia; + case ESupportedNetworks.SCROLL_SEPOLIA: + return scrollSepolia; + case ESupportedNetworks.SCROLL: + return scroll; + case ESupportedNetworks.BASE: + return base; + case ESupportedNetworks.HOLESKY: + return holesky; + case ESupportedNetworks.LINEA: + return linea; + case ESupportedNetworks.BSC: + return bsc; + case ESupportedNetworks.GNOSIS_CHAIN: + return gnosis; + case ESupportedNetworks.POLYGON: + return polygon; + case ESupportedNetworks.OPTIMISM: + return optimism; + case ESupportedNetworks.OPTIMISM_SEPOLIA: + return optimismSepolia; + default: + throw new Error(`Unsupported network: ${network}`); + } +}; diff --git a/packages/coordinator/ts/sessionKeys/__tests__/sessionKeys.service.test.ts b/packages/coordinator/ts/sessionKeys/__tests__/sessionKeys.service.test.ts index 2795da94..6ee18d9a 100644 --- a/packages/coordinator/ts/sessionKeys/__tests__/sessionKeys.service.test.ts +++ b/packages/coordinator/ts/sessionKeys/__tests__/sessionKeys.service.test.ts @@ -1,12 +1,7 @@ import dotenv from "dotenv"; -import { ZeroAddress } from "ethers"; import { zeroAddress } from "viem"; -import { optimismSepolia } from "viem/chains"; -import { KeyLike } from "crypto"; - -import { ErrorCodes } from "../../common"; -import { CryptoService } from "../../crypto/crypto.service"; +import { ErrorCodes, ESupportedNetworks } from "../../common"; import { FileService } from "../../file/file.service"; import { SessionKeysService } from "../sessionKeys.service"; @@ -15,21 +10,18 @@ import { mockSessionKeyApproval } from "./utils"; dotenv.config(); describe("SessionKeysService", () => { - const cryptoService = new CryptoService(); const fileService = new FileService(); let sessionKeysService: SessionKeysService; - let publicKey: KeyLike; - beforeAll(async () => { - publicKey = (await fileService.getPublicKey()).publicKey; - sessionKeysService = new SessionKeysService(cryptoService, fileService); + beforeAll(() => { + sessionKeysService = new SessionKeysService(fileService); }); describe("generateSessionKey", () => { test("should generate and store a session key", () => { const sessionKeyAddress = sessionKeysService.generateSessionKey(); expect(sessionKeyAddress).toBeDefined(); - expect(sessionKeyAddress).not.toEqual(ZeroAddress); + expect(sessionKeyAddress).not.toEqual(zeroAddress); const sessionKey = fileService.getSessionKey(sessionKeyAddress.sessionKeyAddress); expect(sessionKey).toBeDefined(); @@ -40,7 +32,7 @@ describe("SessionKeysService", () => { test("should delete a session key", () => { const sessionKeyAddress = sessionKeysService.generateSessionKey(); expect(sessionKeyAddress).toBeDefined(); - expect(sessionKeyAddress).not.toEqual(ZeroAddress); + expect(sessionKeyAddress).not.toEqual(zeroAddress); const sessionKey = fileService.getSessionKey(sessionKeyAddress.sessionKeyAddress); expect(sessionKey).toBeDefined(); @@ -54,22 +46,19 @@ describe("SessionKeysService", () => { describe("generateClientFromSessionKey", () => { test("should fail to generate a client with an invalid approval", async () => { const sessionKeyAddress = sessionKeysService.generateSessionKey(); - const approval = await mockSessionKeyApproval(sessionKeyAddress.sessionKeyAddress); - const encryptedApproval = cryptoService.encrypt(publicKey, approval); await expect( sessionKeysService.generateClientFromSessionKey( sessionKeyAddress.sessionKeyAddress, - encryptedApproval, - optimismSepolia, + "0xinvalid", + ESupportedNetworks.OPTIMISM_SEPOLIA, ), ).rejects.toThrow(ErrorCodes.INVALID_APPROVAL); }); test("should throw when given a non existent session key address", async () => { const approval = await mockSessionKeyApproval(zeroAddress); - const encryptedApproval = cryptoService.encrypt(publicKey, approval); await expect( - sessionKeysService.generateClientFromSessionKey(zeroAddress, encryptedApproval, optimismSepolia), + sessionKeysService.generateClientFromSessionKey(zeroAddress, approval, ESupportedNetworks.OPTIMISM_SEPOLIA), ).rejects.toThrow(ErrorCodes.SESSION_KEY_NOT_FOUND); }); @@ -86,17 +75,16 @@ describe("SessionKeysService", () => { const sessionKeyAddress = sessionKeysService.generateSessionKey(); const approval = await mockSessionKeyApproval(sessionKeyAddress.sessionKeyAddress); - const encryptedApproval = cryptoService.encrypt(publicKey, approval); const client = await sessionKeysService.generateClientFromSessionKey( sessionKeyAddress.sessionKeyAddress, - encryptedApproval, - optimismSepolia, + approval, + ESupportedNetworks.OPTIMISM_SEPOLIA, ); expect(mockGenerateClientFromSessionKey).toHaveBeenCalledWith( sessionKeyAddress.sessionKeyAddress, - encryptedApproval, - optimismSepolia, + approval, + ESupportedNetworks.OPTIMISM_SEPOLIA, ); expect(client).toEqual({ mockedClient: true }); }); diff --git a/packages/coordinator/ts/sessionKeys/sessionKeys.module.ts b/packages/coordinator/ts/sessionKeys/sessionKeys.module.ts index fdc40cb3..4c48e381 100644 --- a/packages/coordinator/ts/sessionKeys/sessionKeys.module.ts +++ b/packages/coordinator/ts/sessionKeys/sessionKeys.module.ts @@ -1,14 +1,14 @@ import { Module } from "@nestjs/common"; -import { CryptoModule } from "../crypto/crypto.module"; import { FileModule } from "../file/file.module"; import { SessionKeysController } from "./sessionKeys.controller"; import { SessionKeysService } from "./sessionKeys.service"; @Module({ - imports: [FileModule, CryptoModule], + imports: [FileModule], controllers: [SessionKeysController], providers: [SessionKeysService], + exports: [SessionKeysService], }) export class SessionKeysModule {} diff --git a/packages/coordinator/ts/sessionKeys/sessionKeys.service.ts b/packages/coordinator/ts/sessionKeys/sessionKeys.service.ts index 2ecd67bd..27825573 100644 --- a/packages/coordinator/ts/sessionKeys/sessionKeys.service.ts +++ b/packages/coordinator/ts/sessionKeys/sessionKeys.service.ts @@ -5,12 +5,14 @@ import { createKernelAccountClient, KernelAccountClient, KernelSmartAccount } fr import { KERNEL_V3_1 } from "@zerodev/sdk/constants"; import { ENTRYPOINT_ADDRESS_V07 } from "permissionless"; import { ENTRYPOINT_ADDRESS_V07_TYPE } from "permissionless/types"; -import { type Chain, createPublicClient, type Hex, http, type HttpTransport, type Transport } from "viem"; +import { http } from "viem"; import { generatePrivateKey, privateKeyToAccount } from "viem/accounts"; -import { ErrorCodes } from "../common"; -import { genPimlicoRPCUrl } from "../common/accountAbstraction"; -import { CryptoService } from "../crypto/crypto.service"; +import type { Chain, Hex, HttpTransport, Transport } from "viem"; + +import { ErrorCodes, ESupportedNetworks } from "../common"; +import { genPimlicoRPCUrl, getPublicClient } from "../common/accountAbstraction"; +import { viemChain } from "../common/networks"; import { FileService } from "../file/file.service"; import { IGenerateSessionKeyReturn } from "./types"; @@ -31,10 +33,7 @@ export class SessionKeysService { * @param cryptoService - crypto service * @param fileService - file service */ - constructor( - private readonly cryptoService: CryptoService, - private readonly fileService: FileService, - ) { + constructor(private readonly fileService: FileService) { this.logger = new Logger(SessionKeysService.name); } @@ -64,26 +63,22 @@ export class SessionKeysService { * Generate a KernelClient from a session key and an approval * * @param sessionKeyAddress - the address of the session key - * @param encryptedApproval - the encrypted approval string + * @param approval - the approval string * @param chain - the chain to use - * @returns + * @returns a KernelAccountClient */ async generateClientFromSessionKey( sessionKeyAddress: Hex, - encryptedApproval: string, - chain: Chain, + approval: string, + chain: ESupportedNetworks, ): Promise< KernelAccountClient< ENTRYPOINT_ADDRESS_V07_TYPE, Transport, - undefined, - KernelSmartAccount + Chain, + KernelSmartAccount > > { - // the approval will have been encrypted so we need to decrypt it - const { privateKey } = await this.fileService.getPrivateKey(); - const approval = this.cryptoService.decrypt(privateKey, encryptedApproval); - // retrieve the session key from the file service const sessionKey = this.fileService.getSessionKey(sessionKeyAddress); @@ -93,10 +88,8 @@ export class SessionKeysService { } // get the bundler url and create a public client - const bundlerUrl = genPimlicoRPCUrl(chain.name); - const publicClient = createPublicClient({ - transport: http(bundlerUrl), - }); + const bundlerUrl = genPimlicoRPCUrl(chain); + const publicClient = getPublicClient(chain); // Using a stored private key const sessionKeySigner = toECDSASigner({ @@ -117,9 +110,10 @@ export class SessionKeysService { bundlerTransport: http(bundlerUrl), entryPoint: ENTRYPOINT_ADDRESS_V07, account: sessionKeyAccount, + chain: viemChain(chain), }); } catch (error) { - this.logger.error("Error deserializing permission account", error); + this.logger.error(`Error: ${ErrorCodes.INVALID_APPROVAL}`, error); throw new Error(ErrorCodes.INVALID_APPROVAL); } }