Skip to content

Commit

Permalink
MERISK-1179: Add ability to reschedule tasks (#199)
Browse files Browse the repository at this point in the history
Signed-off-by: Mariia Maksimova <[email protected]>
  • Loading branch information
mariia-maksimova authored Jul 3, 2024
1 parent 9d3af41 commit 69aa3b8
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package com.transferwise.tasks.testapp;

import static com.transferwise.tasks.domain.TaskStatus.WAITING;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.transferwise.common.baseutils.UuidUtils;
import com.transferwise.tasks.BaseIntTest;
import com.transferwise.tasks.ITaskDataSerializer;
import com.transferwise.tasks.ITasksService;
import com.transferwise.tasks.ITasksService.RescheduleTaskResponse.Result;
import com.transferwise.tasks.dao.ITaskDao;
import com.transferwise.tasks.domain.Task;
import com.transferwise.tasks.domain.TaskStatus;
import com.transferwise.tasks.management.ITasksManagementService;
import com.transferwise.tasks.test.ITestTasksService;
import io.micrometer.core.instrument.Counter;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.springframework.beans.factory.annotation.Autowired;

@Slf4j
public class TaskReschedulingIntTest extends BaseIntTest {

@Autowired
private ITasksService tasksService;
@Autowired
private ITestTasksService testTasksService;
@Autowired
private ITaskDataSerializer taskDataSerializer;
@Autowired
private ITasksManagementService tasksManagementService;
@Autowired
private ITaskDao taskDao;

@BeforeEach
void setup() {
transactionsHelper.withTransaction().asNew().call(() -> {
testTasksService.reset();
return null;
});
}

@Test
void taskCanBeSuccessfullyRescheduled() {
testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor);
UUID taskId = UuidUtils.generatePrefixCombUuid();

transactionsHelper.withTransaction().asNew().call(() ->
tasksService.addTask(new ITasksService.AddTaskRequest()
.setTaskId(taskId)
.setData(taskDataSerializer.serialize("I want to be rescheduled"))
.setType("test").setRunAfterTime(ZonedDateTime.now().plusHours(1)))
);

await().until(() -> !testTasksService.getWaitingTasks("test", null).isEmpty());

var task = tasksManagementService.getTasksById(
new ITasksManagementService.GetTasksByIdRequest().setTaskIds(List.of(taskId))
).getTasks().stream().filter(t -> t.getTaskVersionId().getId().equals(taskId)).findFirst().orElseThrow();

assertTrue(transactionsHelper.withTransaction().asNew().call(() ->
tasksService.rescheduleTask(
new ITasksService.RescheduleTaskRequest()
.setTaskId(taskId)
.setVersion(task.getTaskVersionId().getVersion())
.setRunAfterTime(ZonedDateTime.now().minusHours(1))
).getResult() == Result.OK
));

await().until(() -> testTasksService.getTasks("test", null, WAITING).isEmpty());
await().until(() -> resultRegisteringSyncTaskProcessor.getTaskResults().get(taskId) != null);
assertEquals(0, getFailedNextEventTimeChangeCount());
assertEquals(1, getTaskRescheduledCount());
}

@Test
void taskWillNotBeRescheduleIfVersionHasAlreadyChanged() {
testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor);
final long initialFailedNextEventTimeChangeCount = getFailedNextEventTimeChangeCount();
final UUID taskId = UuidUtils.generatePrefixCombUuid();

transactionsHelper.withTransaction().asNew().call(() ->
tasksService.addTask(new ITasksService.AddTaskRequest()
.setTaskId(taskId)
.setData(taskDataSerializer.serialize("I want to be rescheduled too!"))
.setType("test").setRunAfterTime(ZonedDateTime.now().plusHours(1)))
);

await().until(() -> !testTasksService.getWaitingTasks("test", null).isEmpty());

var task = tasksManagementService.getTasksById(
new ITasksManagementService.GetTasksByIdRequest().setTaskIds(List.of(taskId))
).getTasks().stream().filter(t -> t.getTaskVersionId().getId().equals(taskId)).findFirst().orElseThrow();

assertFalse(
transactionsHelper.withTransaction().asNew().call(() ->
tasksService.rescheduleTask(
new ITasksService.RescheduleTaskRequest()
.setTaskId(taskId)
.setVersion(task.getTaskVersionId().getVersion() - 1)
.setRunAfterTime(ZonedDateTime.now().plusHours(2))
).getResult() == Result.OK
)
);
assertEquals(initialFailedNextEventTimeChangeCount + 1, getFailedNextEventTimeChangeCount());
assertEquals(0, getTaskRescheduledCount());
}

@ParameterizedTest
@EnumSource(value = TaskStatus.class,
names = {"WAITING", "UNKNOWN"},
mode = EnumSource.Mode.EXCLUDE)
void taskWillNotBeRescheduleIfNotWaiting(TaskStatus status) {
testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor);
final long initialFailedNextEventTimeChangeCount = getFailedNextEventTimeChangeCount();
final UUID taskId = UuidUtils.generatePrefixCombUuid();

transactionsHelper.withTransaction().asNew().call(() ->
tasksService.addTask(new ITasksService.AddTaskRequest()
.setTaskId(taskId)
.setData(taskDataSerializer.serialize("I do not want to be rescheduled!"))
.setType("test").setRunAfterTime(ZonedDateTime.now().plusHours(2)))
);

await().until(() -> !testTasksService.getWaitingTasks("test", null).isEmpty());
List<Task> tasks = testTasksService.getWaitingTasks("test", null);
Task task = tasks.stream().filter(t -> t.getId().equals(taskId)).findFirst().orElseThrow();

transactionsHelper.withTransaction().asNew().call(() ->
tasksService.resumeTask(new ITasksService.ResumeTaskRequest().setTaskId(taskId).setVersion(task.getVersion()))
);

await().until(() -> testTasksService.getWaitingTasks("test", null).isEmpty());

var updateTask = tasksManagementService.getTasksById(
new ITasksManagementService.GetTasksByIdRequest().setTaskIds(List.of(taskId))
).getTasks().stream().filter(t -> t.getTaskVersionId().getId().equals(taskId)).findFirst().orElseThrow();

taskDao.setStatus(taskId, status, updateTask.getTaskVersionId().getVersion());

var finalTask = tasksManagementService.getTasksById(
new ITasksManagementService.GetTasksByIdRequest().setTaskIds(List.of(taskId))
).getTasks().stream().filter(t -> t.getTaskVersionId().getId().equals(taskId)).findFirst().orElseThrow();

assertFalse(
transactionsHelper.withTransaction().asNew().call(() ->
tasksService.rescheduleTask(
new ITasksService.RescheduleTaskRequest()
.setTaskId(taskId)
.setVersion(finalTask.getTaskVersionId().getVersion())
.setRunAfterTime(ZonedDateTime.now().plusHours(2))
).getResult() == Result.OK
)
);
assertEquals(initialFailedNextEventTimeChangeCount + 1, getFailedNextEventTimeChangeCount());
assertEquals(0, getTaskRescheduledCount());
}

private long getFailedNextEventTimeChangeCount() {
Counter counter = meterRegistry.find("twTasks.tasks.failedNextEventTimeChangeCount").tags(
"taskType", "test"
).counter();

if (counter == null) {
return 0;
} else {
return (long) counter.count();
}
}

private long getTaskRescheduledCount() {
Counter counter = meterRegistry.find("twTasks.tasks.rescheduledCount").tags(
"taskType", "test"
).counter();

if (counter == null) {
return 0;
} else {
return (long) counter.count();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,32 @@ class ResumeTaskRequest {
private boolean force;
}

/**
* Reschedules a task in WAITING state. It is useful, when you want to change the next time the task is executed.
*
* <p>If the task in another state NOT_ALLOWED would be returned.
*/
RescheduleTaskResponse rescheduleTask(RescheduleTaskRequest request);

@Data
@Accessors(chain = true)
class RescheduleTaskRequest {
private UUID taskId;
private long version;
private ZonedDateTime runAfterTime;
}

@Data
@Accessors(chain = true)
class RescheduleTaskResponse {
private UUID taskId;
private Result result;

public enum Result {
OK, NOT_FOUND, NOT_ALLOWED, FAILED
}
}

void startTasksProcessing(String bucketId);

Future<Void> stopTasksProcessing(String bucketId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import com.transferwise.common.context.TwContextClockHolder;
import com.transferwise.common.context.UnitOfWorkManager;
import com.transferwise.common.gracefulshutdown.GracefulShutdownStrategy;
import com.transferwise.tasks.ITasksService.RescheduleTaskResponse.Result;
import com.transferwise.tasks.dao.ITaskDao;
import com.transferwise.tasks.domain.BaseTask;
import com.transferwise.tasks.domain.BaseTask1;
import com.transferwise.tasks.domain.FullTaskRecord;
import com.transferwise.tasks.domain.TaskStatus;
import com.transferwise.tasks.entrypoints.EntryPoint;
import com.transferwise.tasks.entrypoints.EntryPointsGroups;
Expand Down Expand Up @@ -189,6 +191,47 @@ public boolean resumeTask(ResumeTaskRequest request) {
});
}

@Override
@EntryPoint(usesExisting = true)
@Transactional(rollbackFor = Exception.class)
public RescheduleTaskResponse rescheduleTask(RescheduleTaskRequest request) {
return entryPointsHelper.continueOrCreate(EntryPointsGroups.TW_TASKS_ENGINE, EntryPointsNames.RESCHEDULE_TASK,
() -> {
UUID taskId = request.getTaskId();
mdcService.put(request.getTaskId(), request.getVersion());

FullTaskRecord task = taskDao.getTask(taskId, FullTaskRecord.class);

if (task == null) {
log.debug("Cannot reschedule task '" + taskId + "' as it was not found.");
return new RescheduleTaskResponse().setResult(Result.NOT_FOUND).setTaskId(taskId);
}

mdcService.put(task);

long version = task.getVersion();

if (version != request.getVersion()) {
coreMetricsTemplate.registerFailedNextEventTimeChange(task.getType(), task.getNextEventTime(), request.getRunAfterTime());
log.debug("Expected version " + request.getVersion() + " does not match " + version + ".");
return new RescheduleTaskResponse().setResult(Result.NOT_FOUND).setTaskId(taskId);
}

if (task.getStatus().equals(TaskStatus.WAITING.name())) {
if (!taskDao.setNextEventTime(taskId, request.getRunAfterTime(), version, TaskStatus.WAITING.name())) {
coreMetricsTemplate.registerFailedNextEventTimeChange(task.getType(), task.getNextEventTime(), request.getRunAfterTime());
return new RescheduleTaskResponse().setResult(RescheduleTaskResponse.Result.FAILED).setTaskId(taskId);
} else {
coreMetricsTemplate.registerTaskRescheduled(null, task.getType());
return new RescheduleTaskResponse().setResult(RescheduleTaskResponse.Result.OK).setTaskId(taskId);
}
}

coreMetricsTemplate.registerFailedNextEventTimeChange(task.getType(), task.getNextEventTime(), request.getRunAfterTime());
return new RescheduleTaskResponse().setResult(RescheduleTaskResponse.Result.NOT_ALLOWED).setTaskId(taskId);
});
}

@Override
public void startTasksProcessing(String bucketId) {
tasksExecutionTriggerer.startTasksProcessing(bucketId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class DeleteFinishedOldTasksResult {

boolean setStatus(UUID taskId, TaskStatus status, long version);

boolean setNextEventTime(UUID taskId, ZonedDateTime nextEventTime, long version, String state);

boolean markAsSubmitted(UUID taskId, long version, ZonedDateTime maxStuckTime);

Long getTaskVersion(UUID id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public JdbcTaskDao(DataSource dataSource, ITaskSqlMapper sqlMapper) {
protected String grabForProcessingWithStatusAssertionSql;
protected String grabForProcessingSql;
protected String setStatusSql;
protected String setNextEventTimeSql;
protected String getStuckTasksSql;
protected String prepareStuckOnProcessingTaskForResumingSql;
protected String prepareStuckOnProcessingTaskForResumingSql1;
Expand Down Expand Up @@ -157,6 +158,8 @@ public void afterPropertiesSet() {
+ ",processing_start_time=?,next_event_time=?,processing_tries_count=processing_tries_count+1"
+ ",state_time=?,time_updated=?,version=? where id=? and version=?";
setStatusSql = "update " + taskTable + " set status=?,next_event_time=?,state_time=?,time_updated=?,version=? where id=? and version=?";
setNextEventTimeSql = "update " + taskTable
+ " set next_event_time=?,state_time=?,time_updated=?,version=? where id=? and version=? and status=?";
getStuckTasksSql = "select id,version,type,priority,status from " + taskTable + " where status=?"
+ " and next_event_time<? order by next_event_time limit ?";
prepareStuckOnProcessingTaskForResumingSql =
Expand Down Expand Up @@ -322,6 +325,14 @@ public boolean setStatus(UUID taskId, TaskStatus status, long version) {
return updatedCount == 1;
}

@Override
@Transactional(rollbackFor = Exception.class)
public boolean setNextEventTime(UUID taskId, ZonedDateTime nextEventTime, long version, String status) {
Timestamp now = Timestamp.from(Instant.now(TwContextClockHolder.getClock()));
int updatedCount = jdbcTemplate.update(setNextEventTimeSql, args(nextEventTime, now, now, version + 1, taskId, version, status));
return updatedCount == 1;
}

@Override
public GetStuckTasksResponse getStuckTasks(int batchSize, TaskStatus status) {
Timestamp now = Timestamp.from(ZonedDateTime.now(TwContextClockHolder.getClock()).toInstant());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public final class EntryPointsNames {
public static final String RESUME_TASK = "resumeTask";
public static final String ASYNC_HANDLE_SUCCESS = "asyncHandleSuccess";
public static final String ASYNC_HANDLE_FAIL = "asyncHandleFail";
public static final String RESCHEDULE_TASK = "rescheduleTask";

private EntryPointsNames() {
throw new AssertionError();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public class CoreMetricsTemplate implements ICoreMetricsTemplate {
private static final String METRIC_TASKS_RETRIES_COUNT = METRIC_PREFIX + "tasks.retriesCount";
private static final String METRIC_TASKS_RESUMINGS_COUNT = METRIC_PREFIX + "tasks.resumingsCount";
private static final String METRIC_TASKS_MARKED_AS_FAILED_COUNT = METRIC_PREFIX + "tasks.markedAsFailedCount";
private static final String METRIC_TASKS_RESCHEDULED_COUNT = METRIC_PREFIX + "tasks.rescheduledCount";
private static final String METRIC_TASKS_FAILED_NEXT_EVENT_TIME_CHANGE_COUNT = METRIC_PREFIX + "tasks.failedNextEventTimeChangeCount";
private static final String METRIC_TASKS_ADDINGS_COUNT = METRIC_PREFIX + "task.addings.count";
private static final String METRIC_TASKS_SERVICE_IN_PROGRESS_TRIGGERINGS_COUNT = METRIC_PREFIX + "tasksService.inProgressTriggeringsCount";
private static final String METRIC_TASKS_SERVICE_ACTIVE_TRIGGERINGS_COUNT = METRIC_PREFIX + "tasksService.activeTriggeringsCount";
Expand Down Expand Up @@ -95,6 +97,8 @@ public class CoreMetricsTemplate implements ICoreMetricsTemplate {
private static final String TAG_PROCESSING_RESULT = "processingResult";
private static final String TAG_FROM_STATUS = "fromStatus";
private static final String TAG_TO_STATUS = "toStatus";
private static final String TAG_FROM_NEXT_EVENT_TIME = "fromNextEventTime";
private static final String TAG_TO_NEXT_EVENT_TIME = "toNextEventTime";
private static final String TAG_TASK_STATUS = "taskStatus";
private static final String TAG_BUCKET_ID = "bucketId";
private static final String TAG_SYNC = "sync";
Expand Down Expand Up @@ -161,6 +165,19 @@ public void registerFailedStatusChange(String taskType, String fromStatus, TaskS
.increment();
}

@Override
public void registerTaskRescheduled(String bucketId, String taskType) {
meterCache.counter(METRIC_TASKS_RESCHEDULED_COUNT, TagsSet.of(TAG_BUCKET_ID, resolveBucketId(bucketId), TAG_TASK_TYPE, taskType))
.increment();
}

@Override
public void registerFailedNextEventTimeChange(String taskType, ZonedDateTime fromNextEventTime, ZonedDateTime toNextEventTime) {
meterCache.counter(METRIC_TASKS_FAILED_NEXT_EVENT_TIME_CHANGE_COUNT, TagsSet.of(TAG_TASK_TYPE, taskType,
TAG_FROM_NEXT_EVENT_TIME, fromNextEventTime.toString(), TAG_TO_NEXT_EVENT_TIME, toNextEventTime.toString()))
.increment();
}

@Override
public void registerTaskGrabbingResponse(String bucketId, String taskType, int priority, ProcessTaskResponse processTaskResponse) {
meterCache.counter(METRIC_TASKS_TASK_GRABBING, TagsSet.of(TAG_TASK_TYPE, taskType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public interface ICoreMetricsTemplate {

void registerTaskMarkedAsFailed(String bucketId, String taskType);

void registerTaskRescheduled(String bucketId, String taskType);

void registerDuplicateTask(String taskType, boolean expected);

void registerScheduledTaskResuming(String taskType);
Expand All @@ -45,6 +47,8 @@ public interface ICoreMetricsTemplate {

void registerFailedStatusChange(String taskType, String fromStatus, TaskStatus toStatus);

void registerFailedNextEventTimeChange(String taskType, ZonedDateTime fromNextEventTime, ZonedDateTime toNextEventTime);

void registerTaskGrabbingResponse(String bucketId, String type, int priority, ProcessTaskResponse processTaskResponse);

void debugPriorityQueueCheck(String bucketId, int priority);
Expand Down

0 comments on commit 69aa3b8

Please sign in to comment.