Skip to content

Commit

Permalink
fix(coordinator): cleanup session key service and use plaintext appro…
Browse files Browse the repository at this point in the history
…vals
  • Loading branch information
ctrlc03 committed Sep 8, 2024
1 parent a5b3609 commit 301cf42
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 94 deletions.
35 changes: 35 additions & 0 deletions packages/coordinator/ts/common/accountAbstraction.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dotenv from "dotenv";
import { type Chain, createPublicClient, http, type HttpTransport, type PublicClient } from "viem";

import { ErrorCodes } from "./errors";
import { ESupportedNetworks, viemChain } from "./networks";

dotenv.config();

Expand All @@ -18,3 +20,36 @@ export const genPimlicoRPCUrl = (network: string): string => {

return `https://api.pimlico.io/v2/${network}/rpc?apikey=${pimlicoAPIKey}`;
};

/**
* Generate the RPCUrl for Alchemy based on the chain we need to interact with
* @param network - the network we want to interact with
* @returns the RPCUrl for the network
*/
export const genAlchemyRPCUrl = (network: ESupportedNetworks): string => {
const rpcAPIKey = process.env.RPC_API_KEY;

if (!rpcAPIKey) {
throw new Error(ErrorCodes.RPC_API_KEY_NOT_SET);
}

switch (network) {
case ESupportedNetworks.OPTIMISM_SEPOLIA:
return `https://opt-sepolia.g.alchemy.com/v2/${rpcAPIKey}`;
case ESupportedNetworks.ETHEREUM_SEPOLIA:
return `https://eth-sepolia.g.alchemy.com/v2/${rpcAPIKey}`;
default:
throw new Error(ErrorCodes.UNSUPPORTED_NETWORK);
}
};

/**
* Get a public client
* @param rpcUrl - the RPC URL
* @returns the public client
*/
export const getPublicClient = (chainName: ESupportedNetworks): PublicClient<HttpTransport, Chain> =>
createPublicClient({
transport: http(genAlchemyRPCUrl(chainName)),
chain: viemChain(chainName),
});
2 changes: 2 additions & 0 deletions packages/coordinator/ts/common/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ export enum ErrorCodes {
SESSION_KEY_NOT_FOUND = "8",
PIMLICO_API_KEY_NOT_SET = "9",
INVALID_APPROVAL = "10",
UNSUPPORTED_NETWORK = "11",
RPC_API_KEY_NOT_SET = "12",
}
110 changes: 65 additions & 45 deletions packages/coordinator/ts/common/networks.ts
Original file line number Diff line number Diff line change
@@ -1,64 +1,84 @@
import {
arbitrum,
arbitrumSepolia,
base,
baseSepolia,
bsc,
type Chain,
gnosis,
holesky,
linea,
lineaSepolia,
mainnet,
optimism,
optimismSepolia,
polygon,
scroll,
scrollSepolia,
sepolia,
} from "viem/chains";

export enum ESupportedNetworks {
ETHEREUM = "mainnet",
OPTIMISM = "optimism",
OPTIMISM_SEPOLIA = "optimism-sepolia",
BSC = "bsc",
BSC_CHAPEL = "chapel",
GNOSIS_CHAIN = "gnosis",
FUSE = "fuse",
POLYGON = "matic",
FANTOM_OPERA = "fantom",
ZKSYNC_ERA_TESTNET = "zksync-era-testnet",
BOBA = "boba",
MOONBEAM = "moonbeam",
MOONRIVER = "moonriver",
MOONBASE_ALPHA = "mbase",
FANTOM_TESTNET = "fantom-testnet",
ARBITRUM_ONE = "arbitrum-one",
CELO = "celo",
AVALANCHE_FUJI = "fuji",
AVALANCHE = "avalanche",
CELO_ALFAJORES = "celo-alfajores",
HOLESKY = "holesky",
AURORA = "aurora",
AURORA_TESTNET = "aurora-testnet",
HARMONY = "harmony",
LINEA_SEPOLIA = "linea-sepolia",
GNOSIS_CHIADO = "gnosis-chiado",
MODE_SEPOLIA = "mode-sepolia",
MODE = "mode-mainnet",
BASE_SEPOLIA = "base-sepolia",
ZKSYNC_ERA_SEPOLIA = "zksync-era-sepolia",
POLYGON_ZKEVM = "polygon-zkevm",
ZKSYNC_ERA = "zksync-era",
ETHEREUM_SEPOLIA = "sepolia",
ARBITRUM_SEPOLIA = "arbitrum-sepolia",
LINEA = "linea",
BASE = "base",
SCROLL_SEPOLIA = "scroll-sepolia",
SCROLL = "scroll",
BLAST_MAINNET = "blast-mainnet",
ASTAR_ZKEVM_MAINNET = "astar-zkevm-mainnet",
SEI_TESTNET = "sei-testnet",
BLAST_TESTNET = "blast-testnet",
ETHERLINK_TESTNET = "etherlink-testnet",
XLAYER_SEPOLIA = "xlayer-sepolia",
XLAYER_MAINNET = "xlayer-mainnet",
POLYGON_AMOY = "polygon-amoy",
ZKYOTO_TESTNET = "zkyoto-testnet",
POLYGON_ZKEVM_CARDONA = "polygon-zkevm-cardona",
SEI_MAINNET = "sei-mainnet",
ROOTSTOCK_MAINNET = "rootstock",
IOTEX_MAINNET = "iotex",
NEAR_MAINNET = "near-mainnet",
NEAR_TESTNET = "near-testnet",
COSMOS = "cosmoshub-4",
COSMOS_HUB = "theta-testnet-001",
OSMOSIS = "osmosis-1",
OSMO_TESTNET = "osmo-test-4",
ARWEAVE = "arweave-mainnet",
BITCOIN = "btc",
SOLANA = "solana-mainnet-beta",
INJECTIVE_MAINNET = "injective-mainnet",
INJECTIVE_TESTNET = "injective-testnet",
}

/**
* Get the Viem chain for a given network
*
* @param network - the network to get the chain for
* @returns the Viem chain
*/
export const viemChain = (network: ESupportedNetworks): Chain => {
switch (network) {
case ESupportedNetworks.ETHEREUM:
return mainnet;
case ESupportedNetworks.ETHEREUM_SEPOLIA:
return sepolia;
case ESupportedNetworks.ARBITRUM_ONE:
return arbitrum;
case ESupportedNetworks.ARBITRUM_SEPOLIA:
return arbitrumSepolia;
case ESupportedNetworks.BASE_SEPOLIA:
return baseSepolia;
case ESupportedNetworks.LINEA_SEPOLIA:
return lineaSepolia;
case ESupportedNetworks.SCROLL_SEPOLIA:
return scrollSepolia;
case ESupportedNetworks.SCROLL:
return scroll;
case ESupportedNetworks.BASE:
return base;
case ESupportedNetworks.HOLESKY:
return holesky;
case ESupportedNetworks.LINEA:
return linea;
case ESupportedNetworks.BSC:
return bsc;
case ESupportedNetworks.GNOSIS_CHAIN:
return gnosis;
case ESupportedNetworks.POLYGON:
return polygon;
case ESupportedNetworks.OPTIMISM:
return optimism;
case ESupportedNetworks.OPTIMISM_SEPOLIA:
return optimismSepolia;
default:
throw new Error(`Unsupported network: ${network}`);
}
};
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import dotenv from "dotenv";
import { ZeroAddress } from "ethers";
import { zeroAddress } from "viem";
import { optimismSepolia } from "viem/chains";

import { KeyLike } from "crypto";

import { ErrorCodes } from "../../common";
import { CryptoService } from "../../crypto/crypto.service";
import { ErrorCodes, ESupportedNetworks } from "../../common";
import { FileService } from "../../file/file.service";
import { SessionKeysService } from "../sessionKeys.service";

Expand All @@ -15,21 +10,18 @@ import { mockSessionKeyApproval } from "./utils";
dotenv.config();

describe("SessionKeysService", () => {
const cryptoService = new CryptoService();
const fileService = new FileService();
let sessionKeysService: SessionKeysService;
let publicKey: KeyLike;

beforeAll(async () => {
publicKey = (await fileService.getPublicKey()).publicKey;
sessionKeysService = new SessionKeysService(cryptoService, fileService);
beforeAll(() => {
sessionKeysService = new SessionKeysService(fileService);
});

describe("generateSessionKey", () => {
test("should generate and store a session key", () => {
const sessionKeyAddress = sessionKeysService.generateSessionKey();
expect(sessionKeyAddress).toBeDefined();
expect(sessionKeyAddress).not.toEqual(ZeroAddress);
expect(sessionKeyAddress).not.toEqual(zeroAddress);

const sessionKey = fileService.getSessionKey(sessionKeyAddress.sessionKeyAddress);
expect(sessionKey).toBeDefined();
Expand All @@ -40,7 +32,7 @@ describe("SessionKeysService", () => {
test("should delete a session key", () => {
const sessionKeyAddress = sessionKeysService.generateSessionKey();
expect(sessionKeyAddress).toBeDefined();
expect(sessionKeyAddress).not.toEqual(ZeroAddress);
expect(sessionKeyAddress).not.toEqual(zeroAddress);

const sessionKey = fileService.getSessionKey(sessionKeyAddress.sessionKeyAddress);
expect(sessionKey).toBeDefined();
Expand All @@ -54,22 +46,19 @@ describe("SessionKeysService", () => {
describe("generateClientFromSessionKey", () => {
test("should fail to generate a client with an invalid approval", async () => {
const sessionKeyAddress = sessionKeysService.generateSessionKey();
const approval = await mockSessionKeyApproval(sessionKeyAddress.sessionKeyAddress);
const encryptedApproval = cryptoService.encrypt(publicKey, approval);
await expect(
sessionKeysService.generateClientFromSessionKey(
sessionKeyAddress.sessionKeyAddress,
encryptedApproval,
optimismSepolia,
"0xinvalid",
ESupportedNetworks.OPTIMISM_SEPOLIA,
),
).rejects.toThrow(ErrorCodes.INVALID_APPROVAL);
});

test("should throw when given a non existent session key address", async () => {
const approval = await mockSessionKeyApproval(zeroAddress);
const encryptedApproval = cryptoService.encrypt(publicKey, approval);
await expect(
sessionKeysService.generateClientFromSessionKey(zeroAddress, encryptedApproval, optimismSepolia),
sessionKeysService.generateClientFromSessionKey(zeroAddress, approval, ESupportedNetworks.OPTIMISM_SEPOLIA),
).rejects.toThrow(ErrorCodes.SESSION_KEY_NOT_FOUND);
});

Expand All @@ -86,17 +75,16 @@ describe("SessionKeysService", () => {

const sessionKeyAddress = sessionKeysService.generateSessionKey();
const approval = await mockSessionKeyApproval(sessionKeyAddress.sessionKeyAddress);
const encryptedApproval = cryptoService.encrypt(publicKey, approval);

const client = await sessionKeysService.generateClientFromSessionKey(
sessionKeyAddress.sessionKeyAddress,
encryptedApproval,
optimismSepolia,
approval,
ESupportedNetworks.OPTIMISM_SEPOLIA,
);
expect(mockGenerateClientFromSessionKey).toHaveBeenCalledWith(
sessionKeyAddress.sessionKeyAddress,
encryptedApproval,
optimismSepolia,
approval,
ESupportedNetworks.OPTIMISM_SEPOLIA,
);
expect(client).toEqual({ mockedClient: true });
});
Expand Down
4 changes: 2 additions & 2 deletions packages/coordinator/ts/sessionKeys/sessionKeys.module.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { Module } from "@nestjs/common";

import { CryptoModule } from "../crypto/crypto.module";
import { FileModule } from "../file/file.module";

import { SessionKeysController } from "./sessionKeys.controller";
import { SessionKeysService } from "./sessionKeys.service";

@Module({
imports: [FileModule, CryptoModule],
imports: [FileModule],
controllers: [SessionKeysController],
providers: [SessionKeysService],
exports: [SessionKeysService],
})
export class SessionKeysModule {}
40 changes: 17 additions & 23 deletions packages/coordinator/ts/sessionKeys/sessionKeys.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import { createKernelAccountClient, KernelAccountClient, KernelSmartAccount } fr
import { KERNEL_V3_1 } from "@zerodev/sdk/constants";
import { ENTRYPOINT_ADDRESS_V07 } from "permissionless";
import { ENTRYPOINT_ADDRESS_V07_TYPE } from "permissionless/types";
import { type Chain, createPublicClient, type Hex, http, type HttpTransport, type Transport } from "viem";
import { http } from "viem";
import { generatePrivateKey, privateKeyToAccount } from "viem/accounts";

import { ErrorCodes } from "../common";
import { genPimlicoRPCUrl } from "../common/accountAbstraction";
import { CryptoService } from "../crypto/crypto.service";
import type { Chain, Hex, HttpTransport, Transport } from "viem";

import { ErrorCodes, ESupportedNetworks } from "../common";
import { genPimlicoRPCUrl, getPublicClient } from "../common/accountAbstraction";
import { viemChain } from "../common/networks";
import { FileService } from "../file/file.service";

import { IGenerateSessionKeyReturn } from "./types";
Expand All @@ -31,10 +33,7 @@ export class SessionKeysService {
* @param cryptoService - crypto service
* @param fileService - file service
*/
constructor(
private readonly cryptoService: CryptoService,
private readonly fileService: FileService,
) {
constructor(private readonly fileService: FileService) {
this.logger = new Logger(SessionKeysService.name);
}

Expand Down Expand Up @@ -64,26 +63,22 @@ export class SessionKeysService {
* Generate a KernelClient from a session key and an approval
*
* @param sessionKeyAddress - the address of the session key
* @param encryptedApproval - the encrypted approval string
* @param approval - the approval string
* @param chain - the chain to use
* @returns
* @returns a KernelAccountClient
*/
async generateClientFromSessionKey(
sessionKeyAddress: Hex,
encryptedApproval: string,
chain: Chain,
approval: string,
chain: ESupportedNetworks,
): Promise<
KernelAccountClient<
ENTRYPOINT_ADDRESS_V07_TYPE,
Transport,
undefined,
KernelSmartAccount<ENTRYPOINT_ADDRESS_V07_TYPE, HttpTransport, undefined>
Chain,
KernelSmartAccount<ENTRYPOINT_ADDRESS_V07_TYPE, HttpTransport, Chain>
>
> {
// the approval will have been encrypted so we need to decrypt it
const { privateKey } = await this.fileService.getPrivateKey();
const approval = this.cryptoService.decrypt(privateKey, encryptedApproval);

// retrieve the session key from the file service
const sessionKey = this.fileService.getSessionKey(sessionKeyAddress);

Expand All @@ -93,10 +88,8 @@ export class SessionKeysService {
}

// get the bundler url and create a public client
const bundlerUrl = genPimlicoRPCUrl(chain.name);
const publicClient = createPublicClient({
transport: http(bundlerUrl),
});
const bundlerUrl = genPimlicoRPCUrl(chain);
const publicClient = getPublicClient(chain);

// Using a stored private key
const sessionKeySigner = toECDSASigner({
Expand All @@ -117,9 +110,10 @@ export class SessionKeysService {
bundlerTransport: http(bundlerUrl),
entryPoint: ENTRYPOINT_ADDRESS_V07,
account: sessionKeyAccount,
chain: viemChain(chain),
});
} catch (error) {
this.logger.error("Error deserializing permission account", error);
this.logger.error(`Error: ${ErrorCodes.INVALID_APPROVAL}`, error);
throw new Error(ErrorCodes.INVALID_APPROVAL);
}
}
Expand Down

0 comments on commit 301cf42

Please sign in to comment.