From 7981ee1f6c1b9e4734da0c99b84e3920a60c9de5 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:46:50 +0800 Subject: [PATCH] chore: refactor `starkNet_getCurrentNetwork` api --- .../starknet-snap/src/getCurrentNetwork.ts | 20 ------- packages/starknet-snap/src/index.ts | 4 +- .../src/rpcs/get-current-network.test.ts | 38 ++++++++++++ .../src/rpcs/get-current-network.ts | 58 +++++++++++++++++++ packages/starknet-snap/src/rpcs/index.ts | 1 + .../test/src/getCurrentNetwork.test.ts | 56 ------------------ 6 files changed, 99 insertions(+), 78 deletions(-) delete mode 100644 packages/starknet-snap/src/getCurrentNetwork.ts create mode 100644 packages/starknet-snap/src/rpcs/get-current-network.test.ts create mode 100644 packages/starknet-snap/src/rpcs/get-current-network.ts delete mode 100644 packages/starknet-snap/test/src/getCurrentNetwork.test.ts diff --git a/packages/starknet-snap/src/getCurrentNetwork.ts b/packages/starknet-snap/src/getCurrentNetwork.ts deleted file mode 100644 index f09b0248..00000000 --- a/packages/starknet-snap/src/getCurrentNetwork.ts +++ /dev/null @@ -1,20 +0,0 @@ -import type { ApiParams } from './types/snapApi'; -import { logger } from './utils/logger'; -import { toJson } from './utils/serializer'; -import { getCurrentNetwork as getCurrentNetworkUtil } from './utils/snapUtils'; - -/** - * - * @param params - */ -export async function getCurrentNetwork(params: ApiParams) { - try { - const { state } = params; - const networks = getCurrentNetworkUtil(state); - logger.log(`getCurrentNetwork: networks:\n${toJson(networks, 2)}`); - return networks; - } catch (error) { - logger.error(`Problem found:`, error); - throw error; - } -} diff --git a/packages/starknet-snap/src/index.ts b/packages/starknet-snap/src/index.ts index 80a5349c..a9fdf27a 100644 --- a/packages/starknet-snap/src/index.ts +++ b/packages/starknet-snap/src/index.ts @@ -13,7 +13,6 @@ import { declareContract } from './declareContract'; import { estimateAccDeployFee } from './estimateAccountDeployFee'; import { estimateFees } from './estimateFees'; import { extractPublicKey } from './extractPublicKey'; -import { getCurrentNetwork } from './getCurrentNetwork'; import { getErc20TokenBalance } from './getErc20TokenBalance'; import { getStarkName } from './getStarkName'; import { getStoredErc20Tokens } from './getStoredErc20Tokens'; @@ -48,6 +47,7 @@ import { switchNetwork, getDeploymentData, watchAsset, + getCurrentNetwork, } from './rpcs'; import { sendTransaction } from './sendTransaction'; import { signDeployAccountTransaction } from './signDeployAccountTransaction'; @@ -241,7 +241,7 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => { ); case 'starkNet_getCurrentNetwork': - return await getCurrentNetwork(apiParams); + return await getCurrentNetwork.execute(null); case 'starkNet_getStoredNetworks': return await getStoredNetworks(apiParams); diff --git a/packages/starknet-snap/src/rpcs/get-current-network.test.ts b/packages/starknet-snap/src/rpcs/get-current-network.test.ts new file mode 100644 index 00000000..655b0fa3 --- /dev/null +++ b/packages/starknet-snap/src/rpcs/get-current-network.test.ts @@ -0,0 +1,38 @@ +import { NetworkStateManager } from '../state/network-state-manager'; +import type { Network } from '../types/snapState'; +import { STARKNET_MAINNET_NETWORK } from '../utils/constants'; +import { getCurrentNetwork } from './get-current-network'; + +jest.mock('../utils/logger'); + +describe('getCurrentNetwork', () => { + const mockNetworkStateManager = ({ + currentNetwork = STARKNET_MAINNET_NETWORK, + }: { + currentNetwork?: Network; + }) => { + const getCurrentNetworkSpy = jest.spyOn( + NetworkStateManager.prototype, + 'getCurrentNetwork', + ); + + getCurrentNetworkSpy.mockResolvedValue(currentNetwork); + + return { getCurrentNetworkSpy }; + }; + + it('return the selected network', async () => { + const currentNetwork = STARKNET_MAINNET_NETWORK; + const { getCurrentNetworkSpy } = mockNetworkStateManager({ + currentNetwork, + }); + + const result = await getCurrentNetwork.execute(null); + + expect(getCurrentNetworkSpy).toHaveBeenCalled(); + expect(result).toStrictEqual({ + name: currentNetwork.name, + chainId: currentNetwork.chainId, + }); + }); +}); diff --git a/packages/starknet-snap/src/rpcs/get-current-network.ts b/packages/starknet-snap/src/rpcs/get-current-network.ts new file mode 100644 index 00000000..eb64b9aa --- /dev/null +++ b/packages/starknet-snap/src/rpcs/get-current-network.ts @@ -0,0 +1,58 @@ +import type { constants } from 'starknet'; +import type { Infer } from 'superstruct'; +import { literal, string, object } from 'superstruct'; + +import { NetworkStateManager } from '../state/network-state-manager'; +import { RpcController, ChainIdStruct } from '../utils'; + +export const GetCurrentNetworkRequestStruct = literal(null); + +export const GetCurrentNetworkResponseStruct = object({ + name: string(), + chainId: ChainIdStruct, +}); + +export type GetCurrentNetworkParams = Infer< + typeof GetCurrentNetworkRequestStruct +>; + +export type GetCurrentNetworkResponse = Infer< + typeof GetCurrentNetworkResponseStruct +>; + +/** + * The RPC handler to get the current network. + */ +export class GetCurrentNetworkRpc extends RpcController< + GetCurrentNetworkParams, + GetCurrentNetworkResponse +> { + protected requestStruct = GetCurrentNetworkRequestStruct; + + protected responseStruct = GetCurrentNetworkResponseStruct; + + /** + * Execute the get the current network. + * + * @param _ + * @returns A promise that resolve to the current network. + */ + async execute( + _: GetCurrentNetworkParams, + ): Promise { + return super.execute(_); + } + + protected async handleRequest( + _: GetCurrentNetworkParams, + ): Promise { + const networkStateMgr = new NetworkStateManager(); + const network = await networkStateMgr.getCurrentNetwork(); + return { + name: network.name, + chainId: network.chainId as unknown as constants.StarknetChainId, + }; + } +} + +export const getCurrentNetwork = new GetCurrentNetworkRpc(); diff --git a/packages/starknet-snap/src/rpcs/index.ts b/packages/starknet-snap/src/rpcs/index.ts index d4bd560e..80556c96 100644 --- a/packages/starknet-snap/src/rpcs/index.ts +++ b/packages/starknet-snap/src/rpcs/index.ts @@ -8,3 +8,4 @@ export * from './verify-signature'; export * from './switch-network'; export * from './get-deployment-data'; export * from './watch-asset'; +export * from './get-current-network'; diff --git a/packages/starknet-snap/test/src/getCurrentNetwork.test.ts b/packages/starknet-snap/test/src/getCurrentNetwork.test.ts deleted file mode 100644 index 112e4e07..00000000 --- a/packages/starknet-snap/test/src/getCurrentNetwork.test.ts +++ /dev/null @@ -1,56 +0,0 @@ -import chai, { expect } from 'chai'; -import sinon from 'sinon'; -import sinonChai from 'sinon-chai'; -import { WalletMock } from '../wallet.mock.test'; -import { SnapState } from '../../src/types/snapState'; -import { - STARKNET_MAINNET_NETWORK, - STARKNET_SEPOLIA_TESTNET_NETWORK, -} from '../../src/utils/constants'; -import { getCurrentNetwork } from '../../src/getCurrentNetwork'; -import { Mutex } from 'async-mutex'; -import { ApiParams } from '../../src/types/snapApi'; - -chai.use(sinonChai); -const sandbox = sinon.createSandbox(); - -describe('Test function: getStoredNetworks', function () { - const walletStub = new WalletMock(); - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - currentNetwork: STARKNET_MAINNET_NETWORK, - }; - const apiParams: ApiParams = { - state, - requestParams: {}, - wallet: walletStub, - saveMutex: new Mutex(), - }; - - let stateStub: sinon.SinonStub; - beforeEach(function () { - stateStub = walletStub.rpcStubs.snap_manageState; - stateStub.resolves(state); - }); - - afterEach(function () { - walletStub.reset(); - sandbox.restore(); - }); - - it('should get the current network correctly', async function () { - const result = await getCurrentNetwork(apiParams); - expect(stateStub).not.to.have.been.called; - expect(result).to.be.eql(STARKNET_MAINNET_NETWORK); - }); - - it('should get STARKNET_MAINNET_NETWORK if current network is undefined', async function () { - state.currentNetwork = undefined; - const result = await getCurrentNetwork(apiParams); - expect(stateStub).not.to.have.been.called; - expect(result).to.be.eql(STARKNET_MAINNET_NETWORK); - }); -});