From 167065f994e5b47fc4d04721d8dbe80fda15ea02 Mon Sep 17 00:00:00 2001 From: Petar Penovic Date: Fri, 23 Feb 2024 08:19:22 +0100 Subject: [PATCH] feat: enable base fetch override --- __tests__/rpcChannel.test.ts | 27 +++++++++++++++++------- __tests__/rpcProvider.test.ts | 8 ++++++++ src/channel/rpc_0_6.ts | 32 ++++++++++++++++++++--------- src/channel/rpc_0_7.ts | 32 ++++++++++++++++++++--------- src/types/provider/configuration.ts | 1 + src/utils/responseParser/rpc.ts | 2 -- 6 files changed, 73 insertions(+), 29 deletions(-) diff --git a/__tests__/rpcChannel.test.ts b/__tests__/rpcChannel.test.ts index f82ee7ab5..ac535a543 100644 --- a/__tests__/rpcChannel.test.ts +++ b/__tests__/rpcChannel.test.ts @@ -1,18 +1,31 @@ -import { RPC07 } from '../src'; +import { RPC06, RPC07 } from '../src'; import { createBlockForDevnet, getTestProvider } from './config/fixtures'; import { initializeMatcher } from './config/schema'; -describe('RPC 0.7.0', () => { - const rpcProvider = getTestProvider(false); - const channel = rpcProvider.channel as RPC07.RpcChannel; +describe('RpcChannel', () => { + const { nodeUrl } = getTestProvider(false).channel; + const channel07 = new RPC07.RpcChannel({ nodeUrl }); initializeMatcher(expect); beforeAll(async () => { await createBlockForDevnet(); }); - test('getBlockWithReceipts', async () => { - const response = await channel.getBlockWithReceipts('latest'); - expect(response).toMatchSchemaRef('BlockWithTxReceipts'); + test('baseFetch override', async () => { + const baseFetch = jest.fn(); + const fetchChannel06 = new RPC06.RpcChannel({ nodeUrl, baseFetch }); + const fetchChannel07 = new RPC07.RpcChannel({ nodeUrl, baseFetch }); + (fetchChannel06.fetch as any)(); + expect(baseFetch).toHaveBeenCalledTimes(1); + baseFetch.mockClear(); + (fetchChannel07.fetch as any)(); + expect(baseFetch).toHaveBeenCalledTimes(1); + }); + + describe('RPC 0.7.0', () => { + test('getBlockWithReceipts', async () => { + const response = await channel07.getBlockWithReceipts('latest'); + expect(response).toMatchSchemaRef('BlockWithTxReceipts'); + }); }); }); diff --git a/__tests__/rpcProvider.test.ts b/__tests__/rpcProvider.test.ts index 6c2fe6835..270c7557e 100644 --- a/__tests__/rpcProvider.test.ts +++ b/__tests__/rpcProvider.test.ts @@ -46,6 +46,14 @@ describeIfRpc('RPCProvider', () => { await createBlockForDevnet(); }); + test('baseFetch override', async () => { + const { nodeUrl } = rpcProvider.channel; + const baseFetch = jest.fn(); + const fetchProvider = new RpcProvider({ nodeUrl, baseFetch }); + (fetchProvider.fetch as any)(); + expect(baseFetch.mock.calls.length).toBe(1); + }); + test('getChainId', async () => { const fetchSpy = jest.spyOn(rpcProvider.channel as any, 'fetchEndpoint'); (rpcProvider as any).chainId = undefined as unknown as StarknetChainId; diff --git a/src/channel/rpc_0_6.ts b/src/channel/rpc_0_6.ts index f89ce7eff..8d691d2b3 100644 --- a/src/channel/rpc_0_6.ts +++ b/src/channel/rpc_0_6.ts @@ -40,21 +40,31 @@ export class RpcChannel { public headers: object; - readonly retries: number; - public requestId: number; readonly blockIdentifier: BlockIdentifier; + readonly retries: number; + + readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed + private chainId?: StarknetChainId; private specVersion?: string; - readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed + private baseFetch: NonNullable; constructor(optionsOrProvider?: RpcProviderOptions) { - const { nodeUrl, retries, headers, blockIdentifier, chainId, specVersion, waitMode } = - optionsOrProvider || {}; + const { + baseFetch, + blockIdentifier, + chainId, + headers, + nodeUrl, + retries, + specVersion, + waitMode, + } = optionsOrProvider || {}; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); } else if (nodeUrl) { @@ -62,12 +72,14 @@ export class RpcChannel { } else { this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default); } - this.retries = retries || defaultOptions.retries; - this.headers = { ...defaultOptions.headers, ...headers }; - this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier; + this.baseFetch = baseFetch ?? fetch; + this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier; this.chainId = chainId; + this.headers = { ...defaultOptions.headers, ...headers }; + this.retries = retries ?? defaultOptions.retries; this.specVersion = specVersion; - this.waitMode = waitMode || false; + this.waitMode = waitMode ?? false; + this.requestId = 0; } @@ -82,7 +94,7 @@ export class RpcChannel { method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return this.baseFetch(this.nodeUrl, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, diff --git a/src/channel/rpc_0_7.ts b/src/channel/rpc_0_7.ts index 741ecd5ef..9fb28861e 100644 --- a/src/channel/rpc_0_7.ts +++ b/src/channel/rpc_0_7.ts @@ -40,21 +40,31 @@ export class RpcChannel { public headers: object; - readonly retries: number; - public requestId: number; readonly blockIdentifier: BlockIdentifier; + readonly retries: number; + + readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed + private chainId?: StarknetChainId; private specVersion?: string; - readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed + private baseFetch: NonNullable; constructor(optionsOrProvider?: RpcProviderOptions) { - const { nodeUrl, retries, headers, blockIdentifier, chainId, specVersion, waitMode } = - optionsOrProvider || {}; + const { + baseFetch, + blockIdentifier, + chainId, + headers, + nodeUrl, + retries, + specVersion, + waitMode, + } = optionsOrProvider || {}; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); } else if (nodeUrl) { @@ -62,12 +72,14 @@ export class RpcChannel { } else { this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default); } - this.retries = retries || defaultOptions.retries; - this.headers = { ...defaultOptions.headers, ...headers }; - this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier; + this.baseFetch = baseFetch ?? fetch; + this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier; this.chainId = chainId; + this.headers = { ...defaultOptions.headers, ...headers }; + this.retries = retries ?? defaultOptions.retries; this.specVersion = specVersion; - this.waitMode = waitMode || false; + this.waitMode = waitMode ?? false; + this.requestId = 0; } @@ -82,7 +94,7 @@ export class RpcChannel { method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return this.baseFetch(this.nodeUrl, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, diff --git a/src/types/provider/configuration.ts b/src/types/provider/configuration.ts index 71eaf534f..20b5a5f03 100644 --- a/src/types/provider/configuration.ts +++ b/src/types/provider/configuration.ts @@ -12,6 +12,7 @@ export type RpcProviderOptions = { specVersion?: string; default?: boolean; waitMode?: boolean; + baseFetch?: WindowOrWorkerGlobalScope['fetch']; feeMarginPercentage?: { l1BoundMaxAmount: number; l1BoundMaxPricePerUnit: number; diff --git a/src/utils/responseParser/rpc.ts b/src/utils/responseParser/rpc.ts index acce51203..4a006ed94 100644 --- a/src/utils/responseParser/rpc.ts +++ b/src/utils/responseParser/rpc.ts @@ -20,8 +20,6 @@ import { toBigInt } from '../num'; import { isString } from '../shortString'; import { estimateFeeToBounds, estimatedFeeToMaxFee } from '../stark'; import { ResponseParser } from '.'; -import { isString } from '../shortString'; - export class RPCResponseParser implements