diff --git a/extensions/action/test/actions.test.ts b/extensions/action/test/actions.test.ts index 69bdddd..905e74e 100644 --- a/extensions/action/test/actions.test.ts +++ b/extensions/action/test/actions.test.ts @@ -210,9 +210,10 @@ describe('Actions Extension', () => { const abortFn = vi.fn(); - const concurrentAction = action('concurrent-action', async () => {}); + const concurrentAction = action('concurrent-action', async () => sleep(100)); const nonConcurrentAction = action('non-concurrent-action', async (_, __, ___, onAbort) => { onAbort(abortFn); + return sleep(100); }, { concurrent: false, }); @@ -243,6 +244,42 @@ describe('Actions Extension', () => { expect(abortFn).toHaveBeenCalled(); }); + test('Handles custom concurrency methods', async () => { + expect.assertions(2); + + const { + action, + } = instance.store; + + const customConcurrentAction = action('concurrent-action', async (payload: number) => sleep(100), { + concurrent: (payload, runningPayloads) => !runningPayloads.includes(payload), + }); + + let hasCustomConcurrentFailed = false; + + try { + await Promise.all([ + customConcurrentAction(1), + customConcurrentAction(2), + ]); + } catch { + hasCustomConcurrentFailed = true; + } + + expect(hasCustomConcurrentFailed).toBe(false); + + try { + await Promise.all([ + customConcurrentAction(2), + customConcurrentAction(2), + ]); + } catch { + hasCustomConcurrentFailed = true; + } + + expect(hasCustomConcurrentFailed).toBe(true); + }); + test('Handles nested cancellation', async () => { expect.assertions(7); @@ -254,8 +291,8 @@ describe('Actions Extension', () => { const catchAssertion = vi.fn(); - const childAction1 = action('child1', () => sleep(1000)); - const childAction2 = action('child2', () => sleep(1000)); + const childAction1 = action('child1', () => sleep(100)); + const childAction2 = action('child2', () => sleep(100)); const parentAction = action('parent', (_, __, controller) => Promise.all([ childAction1(_, controller), childAction2(_, controller),