diff --git a/extensions/action/src/constants.ts b/extensions/action/src/constants.ts index 9bafbf63..520e4a2b 100644 --- a/extensions/action/src/constants.ts +++ b/extensions/action/src/constants.ts @@ -1,17 +1 @@ -import { - INTERNAL, -} from '@harlem/core'; - -export const SENDER = 'extension:action'; -export const STATE_PROP = `${INTERNAL.prefix}actions` as const; - -export const MUTATIONS = { - init: 'extension:action:init', - register: 'extension:action:register', - incrementRunCount: 'extension:action:increment-run-count', - addInstance: 'extension:action:add-instance', - removeInstance: 'extension:action:remove-instance', - addError: 'extension:action:add-error', - clearErrors: 'extension:action:clear-errors', - resetState: 'extension:action:reset-state', -} as const; \ No newline at end of file +export const SENDER = 'extension:action'; \ No newline at end of file diff --git a/extensions/action/src/index.ts b/extensions/action/src/index.ts index 96f3a2ba..12a79434 100644 --- a/extensions/action/src/index.ts +++ b/extensions/action/src/index.ts @@ -71,49 +71,41 @@ export default function actionExtension(options?: Part return (store: InternalStore) => { store.register('extensions', 'action', () => rootOptions); + const actionTasks = new Map>>(); const actionState = reactive(new Map()); - function setActionState(name: string) { + function setActionState(name: string) { const state = { runCount: 0, - tasks: new Set(), instances: new Map(), errors: new Map(), - } as ActionState; + } as ActionState; actionState.set(name, state); return state; } - function getActionState(name: string) { - return (actionState.get(name) || setActionState(name)) as ActionState; - } - - function updateActionState(name: string, producer: (currentState: ActionState) => Partial) { - const currentState = actionState.get(name); - - if (!currentState) { - return; - } - - const newState = producer(currentState); - - if (newState !== currentState) { - actionState.set(name, { - ...currentState, - ...newState, - }); - } + function getActionState(name: string) { + return (actionState.get(name) || setActionState(name)) as ActionState; } function registerAction(name: string, options: Partial> = {}) { store.register('actions', name, () => options); - return setActionState(name); + setActionState(name); + + const tasks = new Set>(); + actionTasks.set(name, tasks); + + return { + tasks, + }; } function action(name: string, body: ActionBody, options?: Partial>): Action { - registerAction(name, options); + const { + tasks, + } = registerAction(name, options); const { concurrent, @@ -132,16 +124,13 @@ export default function actionExtension(options?: Part } as ActionOptions; const mutate = (mutator: Mutator) => store.write(name, SENDER, mutator); - const incrementRunCount = () => updateActionState(name, ({ runCount }) => ({ - runCount: runCount + 1, - })); + const incrementRunCount = () => getActionState(name).runCount += 1; return ((payload: TPayload, controller?: AbortController) => { const { - tasks, instances, errors, - } = getActionState(name); + } = getActionState(name); if (!concurrent || (typeIsFunction(concurrent) && !concurrent(payload, Array.from(instances.values())))) { abortAction(name, 'New instance started on non-concurrent action'); @@ -264,11 +253,9 @@ export default function actionExtension(options?: Part ([] as string[]) .concat(name) .forEach(name => { - const { - tasks, - } = getActionState(name); + const tasks = actionTasks.get(name); - if (tasks.size) { + if (tasks?.size) { tasks.forEach(task => { task.abort(reason); tasks.delete(task); diff --git a/extensions/action/src/types.ts b/extensions/action/src/types.ts index abb3d6af..129a0401 100644 --- a/extensions/action/src/types.ts +++ b/extensions/action/src/types.ts @@ -26,9 +26,8 @@ export interface ActionAbortStrategies { warn: ActionAbortStrategy; } -export interface ActionState { +export interface ActionState { runCount: number; - tasks: Set>; instances: Map; errors: Map; } diff --git a/extensions/action/test/actions.test.ts b/extensions/action/test/actions.test.ts index 68d3211e..905e74e6 100644 --- a/extensions/action/test/actions.test.ts +++ b/extensions/action/test/actions.test.ts @@ -202,12 +202,19 @@ describe('Actions Extension', () => { }); test('Handles concurrency', async () => { + expect.assertions(3); + const { action, } = instance.store; - const concurrentAction = action('concurrent-action', async () => {}); - const nonConcurrentAction = action('non-concurrent-action', async () => {}, { + const abortFn = vi.fn(); + + const concurrentAction = action('concurrent-action', async () => sleep(100)); + const nonConcurrentAction = action('non-concurrent-action', async (_, __, ___, onAbort) => { + onAbort(abortFn); + return sleep(100); + }, { concurrent: false, }); @@ -234,6 +241,43 @@ describe('Actions Extension', () => { expect(hasConcurrentFailed).toBe(false); expect(hasNonConcurrentFailed).toBe(true); + expect(abortFn).toHaveBeenCalled(); + }); + + test('Handles custom concurrency methods', async () => { + expect.assertions(2); + + const { + action, + } = instance.store; + + const customConcurrentAction = action('concurrent-action', async (payload: number) => sleep(100), { + concurrent: (payload, runningPayloads) => !runningPayloads.includes(payload), + }); + + let hasCustomConcurrentFailed = false; + + try { + await Promise.all([ + customConcurrentAction(1), + customConcurrentAction(2), + ]); + } catch { + hasCustomConcurrentFailed = true; + } + + expect(hasCustomConcurrentFailed).toBe(false); + + try { + await Promise.all([ + customConcurrentAction(2), + customConcurrentAction(2), + ]); + } catch { + hasCustomConcurrentFailed = true; + } + + expect(hasCustomConcurrentFailed).toBe(true); }); test('Handles nested cancellation', async () => { @@ -247,8 +291,8 @@ describe('Actions Extension', () => { const catchAssertion = vi.fn(); - const childAction1 = action('child1', () => sleep(1000)); - const childAction2 = action('child2', () => sleep(1000)); + const childAction1 = action('child1', () => sleep(100)); + const childAction2 = action('child2', () => sleep(100)); const parentAction = action('parent', (_, __, controller) => Promise.all([ childAction1(_, controller), childAction2(_, controller),