Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#11768] Migrate TimerTaskDecorator to TaskDecorator #11769

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.navercorp.pinpoint.common.server.task;

import java.util.TimerTask;
import org.springframework.core.task.TaskDecorator;

/**
* @author HyunGil Jeong
*/
public interface TimerTaskDecorator {
public interface TaskDecoratorFactory {

TimerTask decorate(TimerTask timerTask);
TaskDecorator createDecorator();
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
package com.navercorp.pinpoint.web.realtime.activethread.count.service;

import com.navercorp.pinpoint.common.server.cluster.ClusterKey;
import com.navercorp.pinpoint.common.server.task.TimerTaskDecorator;
import com.navercorp.pinpoint.common.server.task.TimerTaskDecoratorFactory;
import com.navercorp.pinpoint.common.server.task.TaskDecoratorFactory;
import com.navercorp.pinpoint.realtime.dto.ATCSupply;
import com.navercorp.pinpoint.web.realtime.activethread.count.dao.ActiveThreadCountDao;
import com.navercorp.pinpoint.web.realtime.activethread.count.dto.ActiveThreadCountResponse;
import com.navercorp.pinpoint.web.realtime.service.AgentLookupService;
import org.springframework.core.task.TaskDecorator;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -32,7 +32,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand All @@ -47,30 +46,30 @@ public class ActiveThreadCountServiceImpl implements ActiveThreadCountService {

private final ActiveThreadCountDao atcDao;
private final AgentLookupService agentLookupService;
private final TimerTaskDecoratorFactory timerTaskDecoratorFactory;
private final TaskDecoratorFactory taskDecoratorFactory;
private final Scheduler scheduler;
private final Duration emitPeriod;
private final Duration updatePeriod;

public ActiveThreadCountServiceImpl(
ActiveThreadCountDao atcDao,
AgentLookupService agentLookupService,
TimerTaskDecoratorFactory timerTaskDecoratorFactory,
TaskDecoratorFactory taskDecoratorFactory,
ScheduledExecutorService scheduledExecutor,
Duration emitPeriod,
Duration updatePeriod
) {
this.atcDao = Objects.requireNonNull(atcDao, "atcDao");
this.agentLookupService = Objects.requireNonNull(agentLookupService, "agentLookupService");
this.timerTaskDecoratorFactory = Objects.requireNonNull(timerTaskDecoratorFactory, "timerTaskDecoratorFactory");
this.taskDecoratorFactory = Objects.requireNonNull(taskDecoratorFactory, "taskDecoratorFactory");
this.scheduler = Schedulers.fromExecutorService(Objects.requireNonNull(scheduledExecutor, "scheduledExecutor"));
this.emitPeriod = Objects.requireNonNull(emitPeriod, "emitPeriod");
this.updatePeriod = Objects.requireNonNull(updatePeriod, "updatePeriod");
}

@Override
public Flux<ActiveThreadCountResponse> getResponses(String applicationName) {
TimerTaskDecorator taskDecorator = timerTaskDecoratorFactory.createTimerTaskDecorator();
TaskDecorator taskDecorator = taskDecoratorFactory.createDecorator();
SupplyCollector collector = new SupplyCollector(applicationName, emitPeriod.toMillis() * 2);

Map<ClusterKey, Disposable> disposableMap = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -105,9 +104,9 @@ private static ClusterKey extractKey(ATCSupply supply) {
return new ClusterKey(supply.getApplicationName(), supply.getAgentId(), supply.getStartTimestamp());
}

private Mono<List<ClusterKey>> getAgents(TimerTaskDecorator taskDecorator, String applicationName) {
private Mono<List<ClusterKey>> getAgents(TaskDecorator taskDecorator, String applicationName) {
return Mono.create(sink -> {
taskDecorator.decorate(new TimerTask() {
taskDecorator.decorate(new Runnable() {
@Override
public void run() {
sink.success(agentLookupService.getRecentAgents(applicationName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.navercorp.pinpoint.web.realtime;

import com.navercorp.pinpoint.common.server.task.TimerTaskDecoratorFactory;
import com.navercorp.pinpoint.common.server.task.TaskDecoratorFactory;
import com.navercorp.pinpoint.web.frontend.export.FrontendConfigExporter;
import com.navercorp.pinpoint.web.realtime.activethread.count.dao.ActiveThreadCountDao;
import com.navercorp.pinpoint.web.realtime.activethread.count.service.ActiveThreadCountService;
Expand All @@ -30,7 +30,7 @@
import com.navercorp.pinpoint.web.service.ApplicationAgentListService;
import com.navercorp.pinpoint.web.service.EchoService;
import com.navercorp.pinpoint.web.websocket.PinpointWebSocketHandler;
import com.navercorp.pinpoint.web.websocket.PinpointWebSocketTimerTaskDecoratorFactory;
import com.navercorp.pinpoint.web.websocket.WebSocketTaskDecoratorFactory;
import com.navercorp.pinpoint.web.websocket.message.PinpointWebSocketMessageConverter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
Expand Down Expand Up @@ -87,13 +87,13 @@
AgentLookupService agentLookupService,
@Qualifier("pubSubATCSessionScheduledExecutor") ScheduledExecutorService scheduledExecutor,
ActiveThreadCountService.ATCPeriods atcPeriods,
@Autowired(required = false) @Nullable TimerTaskDecoratorFactory timerTaskDecoratorFactory
@Autowired(required = false) @Nullable TaskDecoratorFactory taskDecoratorFactory
) {
return new ActiveThreadCountServiceImpl(
atcDao,
agentLookupService,
Objects.requireNonNullElseGet(timerTaskDecoratorFactory,
PinpointWebSocketTimerTaskDecoratorFactory::new),
Objects.requireNonNullElseGet(taskDecoratorFactory,

Check warning on line 95 in web/src/main/java/com/navercorp/pinpoint/web/realtime/RealtimeConfig.java

View check run for this annotation

Codecov / codecov/patch

web/src/main/java/com/navercorp/pinpoint/web/realtime/RealtimeConfig.java#L95

Added line #L95 was not covered by tests
WebSocketTaskDecoratorFactory::new),
scheduledExecutor,
atcPeriods.getPeriodEmit(),
atcPeriods.getPeriodUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,28 @@

package com.navercorp.pinpoint.web.websocket;

import com.navercorp.pinpoint.common.server.task.TimerTaskDecorator;
import com.navercorp.pinpoint.common.server.task.TimerTaskDecoratorFactory;
import com.navercorp.pinpoint.common.server.task.TaskDecoratorFactory;
import com.navercorp.pinpoint.web.util.SecurityContextUtils;
import org.springframework.core.task.TaskDecorator;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;

import java.util.TimerTask;

/**
* @author HyunGil Jeong
*/
public class PinpointWebSocketTimerTaskDecoratorFactory implements TimerTaskDecoratorFactory {
public class WebSocketTaskDecoratorFactory implements TaskDecoratorFactory {

@Override
public TimerTaskDecorator createTimerTaskDecorator() {
return new SecurityContextPreservingTimerTaskDecorator();
public TaskDecorator createDecorator() {
return new SecurityContextPreservingTaskDecorator();
}

private static class SecurityContextPreservingTimerTaskDecorator implements TimerTaskDecorator {
private static class SecurityContextPreservingTaskDecorator implements TaskDecorator {

private final SecurityContext securityContext;

private SecurityContextPreservingTimerTaskDecorator() {
private SecurityContextPreservingTaskDecorator() {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
final Authentication authentication = SecurityContextUtils.getAuthentication();
if (authentication != null) {
Expand All @@ -49,14 +47,14 @@ private SecurityContextPreservingTimerTaskDecorator() {
}

@Override
public TimerTask decorate(TimerTask timerTask) {
return new TimerTask() {
public Runnable decorate(Runnable task) {
return new Runnable() {
@Override
public void run() {
SecurityContext previousSecurityContext = SecurityContextHolder.getContext();
try {
SecurityContextHolder.setContext(securityContext);
timerTask.run();
task.run();
} finally {
SecurityContextHolder.setContext(previousSecurityContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.navercorp.pinpoint.web.websocket;

import com.navercorp.pinpoint.common.server.task.TimerTaskDecoratorFactory;
import com.navercorp.pinpoint.common.server.task.TaskDecoratorFactory;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand All @@ -33,38 +33,43 @@
/**
* @author HyunGil Jeong
*/
public class PinpointWebSocketTimerTaskDecoratorTest {
public class WebSocketTaskDecoratorFactoryTest {

private final TimerTaskDecoratorFactory timerTaskDecoratorFactory = new PinpointWebSocketTimerTaskDecoratorFactory();
private final TaskDecoratorFactory taskDecoratorFactory = new WebSocketTaskDecoratorFactory();

@Test
public void testAuthenticationPropagation() {
final int numThreads = 3;
final Authentication[] authentications = new Authentication[numThreads];
for (int i = 0; i < authentications.length; i++) {
final String principal = "principal" + i;
final String credential = "credential" + i;
authentications[i] = new TestingAuthenticationToken(principal, credential);
}
final List<Authentication> authentications = sampleAuthentications(numThreads);
List<CompletableFuture<Authentication>> result = new ArrayList<>();
for (Authentication authentication : authentications) {
CompletableFuture<Authentication> future = CompletableFuture.supplyAsync(() -> {
SecurityContext securityContext = new SecurityContextImpl();
securityContext.setAuthentication(authentication);
SecurityContextHolder.setContext(securityContext);
TestTimerTask run = new TestTimerTask();
TimerTask timerTask = timerTaskDecoratorFactory.createTimerTaskDecorator().decorate(run);
timerTask.run();
Runnable task = taskDecoratorFactory.createDecorator().decorate(run);
task.run();
return run.result();
});
result.add(future);
}

for (int i = 0; i < authentications.length; i++) {
Authentication expected = authentications[i];
Authentication actual = result.get(i).join();
Assertions.assertEquals(expected, actual);
int i = 0;
for (Authentication authentication : authentications) {
Authentication actual = result.get(i++).join();
Assertions.assertEquals(authentication, actual);
}
}

private List<Authentication> sampleAuthentications(int numThreads) {
final List<Authentication> result = new ArrayList<>(numThreads);
for (int i = 0; i < numThreads; i++) {
final String principal = "principal" + i;
final String credential = "credential" + i;
result.add(new TestingAuthenticationToken(principal, credential));
}
return result;
}

private static class TestTimerTask extends TimerTask {
Expand Down