diff --git a/src/SmartTransactionsController.test.ts b/src/SmartTransactionsController.test.ts index 5b1f4d6..24290c0 100644 --- a/src/SmartTransactionsController.test.ts +++ b/src/SmartTransactionsController.test.ts @@ -31,7 +31,7 @@ import type { SmartTransactionsControllerEvents, } from './SmartTransactionsController'; import type { SmartTransaction, UnsignedTransaction, Hex } from './types'; -import { SmartTransactionStatuses } from './types'; +import { SmartTransactionStatuses, ClientId } from './types'; import * as utils from './utils'; jest.mock('@ethersproject/bytes', () => ({ @@ -1214,6 +1214,170 @@ describe('SmartTransactionsController', () => { }, ); }); + + it('calls updateTransaction when smart transaction is cancelled and returnTxHashAsap is true', async () => { + const mockUpdateTransaction = jest.fn(); + const defaultState = getDefaultSmartTransactionsControllerState(); + const pendingStx = createStateAfterPending(); + await withController( + { + options: { + updateTransaction: mockUpdateTransaction, + getFeatureFlags: () => ({ + smartTransactions: { + mobileReturnTxHashAsap: true, + }, + }), + getTransactions: () => [ + { + id: 'test-tx-id', + status: TransactionStatus.submitted, + chainId: '0x1', + time: 123, + txParams: { + from: '0x123', + }, + }, + ], + state: { + smartTransactionsState: { + ...defaultState.smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: pendingStx as SmartTransaction[], + }, + }, + }, + }, + }, + async ({ controller }) => { + const smartTransaction = { + uuid: 'uuid1', + status: SmartTransactionStatuses.CANCELLED, + transactionId: 'test-tx-id', + }; + + controller.updateSmartTransaction(smartTransaction); + + expect(mockUpdateTransaction).toHaveBeenCalledWith( + { + id: 'test-tx-id', + status: TransactionStatus.failed, + chainId: '0x1', + time: 123, + txParams: { + from: '0x123', + }, + }, + 'Smart transaction cancelled', + ); + }, + ); + }); + + it('does not call updateTransaction when smart transaction is cancelled but returnTxHashAsap is false', async () => { + const mockUpdateTransaction = jest.fn(); + await withController( + { + options: { + updateTransaction: mockUpdateTransaction, + getFeatureFlags: () => ({ + smartTransactions: { + mobileReturnTxHashAsap: false, + }, + }), + getTransactions: () => [ + { + id: 'test-tx-id', + status: TransactionStatus.submitted, + chainId: '0x1', + time: 123, + txParams: { + from: '0x123', + }, + }, + ], + }, + }, + async ({ controller }) => { + const smartTransaction = { + uuid: 'test-uuid', + status: SmartTransactionStatuses.CANCELLED, + transactionId: 'test-tx-id', + }; + + controller.updateSmartTransaction(smartTransaction); + + expect(mockUpdateTransaction).not.toHaveBeenCalled(); + }, + ); + }); + + it('does not call updateTransaction when transaction is not found in regular transactions', async () => { + const mockUpdateTransaction = jest.fn(); + + await withController( + { + options: { + updateTransaction: mockUpdateTransaction, + getFeatureFlags: () => ({ + smartTransactions: { + mobileReturnTxHashAsap: true, + }, + }), + getTransactions: () => [], + }, + }, + async ({ controller }) => { + const smartTransaction = { + uuid: 'test-uuid', + status: SmartTransactionStatuses.CANCELLED, + transactionId: 'test-tx-id', + }; + + controller.updateSmartTransaction(smartTransaction); + + expect(mockUpdateTransaction).not.toHaveBeenCalled(); + }, + ); + }); + + it('does not call updateTransaction for non-cancelled transactions', async () => { + const mockUpdateTransaction = jest.fn(); + await withController( + { + options: { + updateTransaction: mockUpdateTransaction, + getFeatureFlags: () => ({ + smartTransactions: { + mobileReturnTxHashAsap: true, + }, + }), + getTransactions: () => [ + { + id: 'test-tx-id', + status: TransactionStatus.submitted, + chainId: '0x1', + time: 123, + txParams: { + from: '0x123', + }, + }, + ], + }, + }, + async ({ controller }) => { + const smartTransaction = { + uuid: 'test-uuid', + status: SmartTransactionStatuses.PENDING, + transactionId: 'test-tx-id', + }; + + controller.updateSmartTransaction(smartTransaction); + + expect(mockUpdateTransaction).not.toHaveBeenCalled(); + }, + ); + }); }); describe('cancelSmartTransaction', () => { @@ -1438,7 +1602,7 @@ describe('SmartTransactionsController', () => { const fetchHeaders = { headers: { 'Content-Type': 'application/json', - 'X-Client-Id': 'default', + 'X-Client-Id': ClientId.Mobile, }, }; @@ -1813,6 +1977,7 @@ async function withController( const controller = new SmartTransactionsController({ messenger, + clientId: ClientId.Mobile, getNonceLock: jest.fn().mockResolvedValue({ nextNonce: 'nextNonce', releaseLock: jest.fn(), @@ -1827,6 +1992,8 @@ async function withController( deviceModel: 'ledger', }); }), + getFeatureFlags: jest.fn(), + updateTransaction: jest.fn(), ...options, }); diff --git a/src/SmartTransactionsController.ts b/src/SmartTransactionsController.ts index 1251d86..98391db 100644 --- a/src/SmartTransactionsController.ts +++ b/src/SmartTransactionsController.ts @@ -38,6 +38,8 @@ import type { UnsignedTransaction, GetTransactionsOptions, MetaMetricsProps, + FeatureFlags, + ClientId, } from './types'; import { APIType, SmartTransactionStatuses } from './types'; import { @@ -53,11 +55,11 @@ import { getTxHash, getSmartTransactionMetricsProperties, getSmartTransactionMetricsSensitiveProperties, + getReturnTxHashAsap, } from './utils'; const SECOND = 1000; export const DEFAULT_INTERVAL = SECOND * 5; -const DEFAULT_CLIENT_ID = 'default'; const ETH_QUERY_ERROR_MSG = '`ethQuery` is not defined on SmartTransactionsController'; @@ -178,7 +180,7 @@ export type SmartTransactionsControllerMessenger = type SmartTransactionsControllerOptions = { interval?: number; - clientId?: string; + clientId: ClientId; chainId?: Hex; supportedChainIds?: Hex[]; getNonceLock: TransactionController['getNonceLock']; @@ -198,6 +200,8 @@ type SmartTransactionsControllerOptions = { messenger: SmartTransactionsControllerMessenger; getTransactions: (options?: GetTransactionsOptions) => TransactionMeta[]; getMetaMetricsProps: () => Promise; + getFeatureFlags: () => FeatureFlags; + updateTransaction: (transaction: TransactionMeta, note: string) => void; }; export type SmartTransactionsControllerPollingInput = { @@ -211,7 +215,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo > { #interval: number; - #clientId: string; + #clientId: ClientId; #chainId: Hex; @@ -233,6 +237,10 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo readonly #getMetaMetricsProps: () => Promise; + #getFeatureFlags: SmartTransactionsControllerOptions['getFeatureFlags']; + + #updateTransaction: SmartTransactionsControllerOptions['updateTransaction']; + /* istanbul ignore next */ async #fetch(request: string, options?: RequestInit) { const fetchOptions = { @@ -248,7 +256,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo constructor({ interval = DEFAULT_INTERVAL, - clientId = DEFAULT_CLIENT_ID, + clientId, chainId: InitialChainId = ChainId.mainnet, supportedChainIds = [ChainId.mainnet, ChainId.sepolia], getNonceLock, @@ -258,6 +266,8 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo messenger, getTransactions, getMetaMetricsProps, + getFeatureFlags, + updateTransaction, }: SmartTransactionsControllerOptions) { super({ name: controllerName, @@ -279,6 +289,8 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo this.#getRegularTransactions = getTransactions; this.#trackMetaMetricsEvent = trackMetaMetricsEvent; this.#getMetaMetricsProps = getMetaMetricsProps; + this.#getFeatureFlags = getFeatureFlags; + this.#updateTransaction = updateTransaction; this.initializeSmartTransactionsForChainId(); @@ -530,24 +542,47 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo return; } + const currentSmartTransaction = currentSmartTransactions[currentIndex]; + const nextSmartTransaction = { + ...currentSmartTransaction, + ...smartTransaction, + }; + // We have to emit this event here, because then a txHash is returned to the TransactionController once it's available // and the #doesTransactionNeedConfirmation function will work properly, since it will find the txHash in the regular transactions list. this.messagingSystem.publish( `SmartTransactionsController:smartTransaction`, - smartTransaction, + nextSmartTransaction, ); + if (nextSmartTransaction.status === SmartTransactionStatuses.CANCELLED) { + const returnTxHashAsap = getReturnTxHashAsap( + this.#clientId, + this.#getFeatureFlags()?.smartTransactions, + ); + if (returnTxHashAsap && nextSmartTransaction.transactionId) { + const foundTransaction = this.#getRegularTransactions().find( + (transaction) => + transaction.id === nextSmartTransaction.transactionId, + ); + if (foundTransaction) { + const updatedTransaction = { + ...foundTransaction, + status: TransactionStatus.failed, + }; + this.#updateTransaction( + updatedTransaction as TransactionMeta, + 'Smart transaction cancelled', + ); + } + } + } + if ( (smartTransaction.status === SmartTransactionStatuses.SUCCESS || smartTransaction.status === SmartTransactionStatuses.REVERTED) && !smartTransaction.confirmed ) { - // confirm smart transaction - const currentSmartTransaction = currentSmartTransactions[currentIndex]; - const nextSmartTransaction = { - ...currentSmartTransaction, - ...smartTransaction, - }; await this.#confirmSmartTransaction(nextSmartTransaction, { chainId, ethQuery, @@ -892,6 +927,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo txHash: submitTransactionResponse.txHash, cancellable: true, type: transactionMeta?.type ?? 'swap', + transactionId: transactionMeta?.id, }, { chainId, ethQuery }, ); diff --git a/src/index.test.ts b/src/index.test.ts index 3490185..d382bb3 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -8,6 +8,7 @@ import SmartTransactionsController, { type AllowedActions, type AllowedEvents, } from './SmartTransactionsController'; +import { ClientId } from './types'; describe('default export', () => { it('exports SmartTransactionsController', () => { @@ -30,6 +31,9 @@ describe('default export', () => { getMetaMetricsProps: jest.fn(async () => { return Promise.resolve({}); }), + getFeatureFlags: jest.fn(), + updateTransaction: jest.fn(), + clientId: ClientId.Extension, }); expect(controller).toBeInstanceOf(SmartTransactionsController); jest.clearAllTimers(); diff --git a/src/types.ts b/src/types.ts index d8b774e..10f124d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -44,6 +44,11 @@ export enum SmartTransactionStatuses { RESOLVED = 'resolved', } +export enum ClientId { + Mobile = 'mobile', + Extension = 'extension', +} + export const cancellationReasonToStatusMap = { [SmartTransactionCancellationReason.WOULD_REVERT]: SmartTransactionStatuses.CANCELLED_WOULD_REVERT, @@ -97,6 +102,7 @@ export type SmartTransaction = { accountHardwareType?: string; accountType?: string; deviceModel?: string; + transactionId?: string; // It's an ID for a regular transaction from the TransactionController. }; export type Fee = { @@ -140,3 +146,10 @@ export type MetaMetricsProps = { accountType?: string; deviceModel?: string; }; + +export type FeatureFlags = { + smartTransactions?: { + mobileReturnTxHashAsap?: boolean; + extensionReturnTxHashAsap?: boolean; + }; +}; diff --git a/src/utils.test.ts b/src/utils.test.ts index 773df5a..4997e9d 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -7,6 +7,7 @@ import { APIType, SmartTransactionStatuses, SmartTransactionCancellationReason, + ClientId, } from './types'; import * as utils from './utils'; @@ -48,6 +49,20 @@ describe('src/utils.js', () => { ); }); + it('returns correct URL for ESTIMATE_GAS', () => { + const chainId = '0x1'; // Mainnet in hex + const expectedUrl = `${API_BASE_URL}/networks/1/estimateGas`; + const result = utils.getAPIRequestURL(APIType.ESTIMATE_GAS, chainId); + expect(result).toBe(expectedUrl); + }); + + it('converts hex chainId to decimal for ESTIMATE_GAS', () => { + const chainId = '0x89'; // Polygon in hex (137 in decimal) + const expectedUrl = `${API_BASE_URL}/networks/137/estimateGas`; + const result = utils.getAPIRequestURL(APIType.ESTIMATE_GAS, chainId); + expect(result).toBe(expectedUrl); + }); + it('returns a URL for submitting transactions', () => { expect( utils.getAPIRequestURL(APIType.SUBMIT_TRANSACTIONS, ChainId.mainnet), @@ -294,4 +309,22 @@ describe('src/utils.js', () => { }).toThrow('kzg instance required to instantiate blob tx'); }); }); + + describe('getReturnTxHashAsap', () => { + it('returns extensionReturnTxHashAsap value for Extension client', () => { + const result = utils.getReturnTxHashAsap(ClientId.Extension, { + extensionReturnTxHashAsap: true, + mobileReturnTxHashAsap: false, + }); + expect(result).toBe(true); + }); + + it('returns mobileReturnTxHashAsap value for Mobile client', () => { + const result = utils.getReturnTxHashAsap(ClientId.Mobile, { + extensionReturnTxHashAsap: false, + mobileReturnTxHashAsap: true, + }); + expect(result).toBe(true); + }); + }); }); diff --git a/src/utils.ts b/src/utils.ts index 000e267..058c9d1 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -11,13 +11,18 @@ import _ from 'lodash'; // @ts-ignore import packageJson from '../package.json'; import { API_BASE_URL, SENTINEL_API_BASE_URL_MAP } from './constants'; -import type { SmartTransaction, SmartTransactionsStatus } from './types'; +import type { + SmartTransaction, + SmartTransactionsStatus, + FeatureFlags, +} from './types'; import { APIType, SmartTransactionStatuses, SmartTransactionCancellationReason, SmartTransactionMinedTx, cancellationReasonToStatusMap, + ClientId, } from './types'; export function isSmartTransactionPending(smartTransaction: SmartTransaction) { @@ -264,3 +269,12 @@ export const getSmartTransactionMetricsSensitiveProperties = ( device_model: smartTransaction.deviceModel, }; }; + +export const getReturnTxHashAsap = ( + clientId: ClientId, + smartTransactionsFeatureFlags: FeatureFlags['smartTransactions'], +) => { + return clientId === ClientId.Extension + ? smartTransactionsFeatureFlags?.extensionReturnTxHashAsap + : smartTransactionsFeatureFlags?.mobileReturnTxHashAsap; +};