Skip to content

Commit

Permalink
Merge pull request #71 from andrewcourtice/fix/action-concurrency
Browse files Browse the repository at this point in the history
Fix action concurrency getting wrong task reference
  • Loading branch information
andrewcourtice authored Dec 12, 2022
2 parents 01b59a0 + e1fe050 commit eaad45a
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 56 deletions.
18 changes: 1 addition & 17 deletions extensions/action/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1 @@
import {
INTERNAL,
} from '@harlem/core';

export const SENDER = 'extension:action';
export const STATE_PROP = `${INTERNAL.prefix}actions` as const;

export const MUTATIONS = {
init: 'extension:action:init',
register: 'extension:action:register',
incrementRunCount: 'extension:action:increment-run-count',
addInstance: 'extension:action:add-instance',
removeInstance: 'extension:action:remove-instance',
addError: 'extension:action:add-error',
clearErrors: 'extension:action:clear-errors',
resetState: 'extension:action:reset-state',
} as const;
export const SENDER = 'extension:action';
53 changes: 20 additions & 33 deletions extensions/action/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,49 +71,41 @@ export default function actionExtension<TState extends BaseState>(options?: Part
return (store: InternalStore<TState>) => {
store.register('extensions', 'action', () => rootOptions);

const actionTasks = new Map<string, Set<Task<unknown>>>();
const actionState = reactive(new Map<string, ActionState>());

function setActionState<TPayload = unknown, TResult = unknown>(name: string) {
function setActionState<TPayload = unknown>(name: string) {
const state = {
runCount: 0,
tasks: new Set(),
instances: new Map(),
errors: new Map(),
} as ActionState<TPayload, TResult>;
} as ActionState<TPayload>;

actionState.set(name, state);

return state;
}

function getActionState<TPayload = unknown, TResult = unknown>(name: string) {
return (actionState.get(name) || setActionState(name)) as ActionState<TPayload, TResult>;
}

function updateActionState(name: string, producer: (currentState: ActionState) => Partial<ActionState>) {
const currentState = actionState.get(name);

if (!currentState) {
return;
}

const newState = producer(currentState);

if (newState !== currentState) {
actionState.set(name, {
...currentState,
...newState,
});
}
function getActionState<TPayload = unknown>(name: string) {
return (actionState.get(name) || setActionState(name)) as ActionState<TPayload>;
}

function registerAction(name: string, options: Partial<ActionOptions<any>> = {}) {
store.register('actions', name, () => options);
return setActionState(name);
setActionState(name);

const tasks = new Set<Task<unknown>>();
actionTasks.set(name, tasks);

return {
tasks,
};
}

function action<TPayload, TResult = void>(name: string, body: ActionBody<TState, TPayload, TResult>, options?: Partial<ActionOptions<TPayload>>): Action<TPayload, TResult> {
registerAction(name, options);
const {
tasks,
} = registerAction(name, options);

const {
concurrent,
Expand All @@ -132,16 +124,13 @@ export default function actionExtension<TState extends BaseState>(options?: Part
} as ActionOptions<TPayload>;

const mutate = (mutator: Mutator<TState, undefined, void>) => store.write(name, SENDER, mutator);
const incrementRunCount = () => updateActionState(name, ({ runCount }) => ({
runCount: runCount + 1,
}));
const incrementRunCount = () => getActionState(name).runCount += 1;

return ((payload: TPayload, controller?: AbortController) => {
const {
tasks,
instances,
errors,
} = getActionState<TPayload, TResult>(name);
} = getActionState<TPayload>(name);

if (!concurrent || (typeIsFunction(concurrent) && !concurrent(payload, Array.from(instances.values())))) {
abortAction(name, 'New instance started on non-concurrent action');
Expand Down Expand Up @@ -264,11 +253,9 @@ export default function actionExtension<TState extends BaseState>(options?: Part
([] as string[])
.concat(name)
.forEach(name => {
const {
tasks,
} = getActionState(name);
const tasks = actionTasks.get(name);

if (tasks.size) {
if (tasks?.size) {
tasks.forEach(task => {
task.abort(reason);
tasks.delete(task);
Expand Down
3 changes: 1 addition & 2 deletions extensions/action/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ export interface ActionAbortStrategies {
warn: ActionAbortStrategy;
}

export interface ActionState<TPayload = unknown, TResult = unknown> {
export interface ActionState<TPayload = unknown> {
runCount: number;
tasks: Set<Task<TResult>>;
instances: Map<symbol, TPayload>;
errors: Map<symbol, unknown>;
}
Expand Down
52 changes: 48 additions & 4 deletions extensions/action/test/actions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,19 @@ describe('Actions Extension', () => {
});

test('Handles concurrency', async () => {
expect.assertions(3);

const {
action,
} = instance.store;

const concurrentAction = action('concurrent-action', async () => {});
const nonConcurrentAction = action('non-concurrent-action', async () => {}, {
const abortFn = vi.fn();

const concurrentAction = action('concurrent-action', async () => sleep(100));
const nonConcurrentAction = action('non-concurrent-action', async (_, __, ___, onAbort) => {
onAbort(abortFn);
return sleep(100);
}, {
concurrent: false,
});

Expand All @@ -234,6 +241,43 @@ describe('Actions Extension', () => {

expect(hasConcurrentFailed).toBe(false);
expect(hasNonConcurrentFailed).toBe(true);
expect(abortFn).toHaveBeenCalled();
});

test('Handles custom concurrency methods', async () => {
expect.assertions(2);

const {
action,
} = instance.store;

const customConcurrentAction = action('concurrent-action', async (payload: number) => sleep(100), {
concurrent: (payload, runningPayloads) => !runningPayloads.includes(payload),
});

let hasCustomConcurrentFailed = false;

try {
await Promise.all([
customConcurrentAction(1),
customConcurrentAction(2),
]);
} catch {
hasCustomConcurrentFailed = true;
}

expect(hasCustomConcurrentFailed).toBe(false);

try {
await Promise.all([
customConcurrentAction(2),
customConcurrentAction(2),
]);
} catch {
hasCustomConcurrentFailed = true;
}

expect(hasCustomConcurrentFailed).toBe(true);
});

test('Handles nested cancellation', async () => {
Expand All @@ -247,8 +291,8 @@ describe('Actions Extension', () => {

const catchAssertion = vi.fn();

const childAction1 = action('child1', () => sleep(1000));
const childAction2 = action('child2', () => sleep(1000));
const childAction1 = action('child1', () => sleep(100));
const childAction2 = action('child2', () => sleep(100));
const parentAction = action('parent', (_, __, controller) => Promise.all([
childAction1(_, controller),
childAction2(_, controller),
Expand Down

1 comment on commit eaad45a

@vercel
Copy link

@vercel vercel bot commented on eaad45a Dec 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.