Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client gametest threading tweaks #4304

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
*
* <p>Client gametests run on the client gametest thread. Use the functions inside
* {@link net.fabricmc.fabric.api.client.gametest.v1.ClientGameTestContext ClientGameTestContext} and other test helper
* classes to run code on the correct thread. The game remains paused unless you explicitly unpause it using various
* waiting functions such as
* classes to run code on the correct thread. Exceptions are transparently rethrown on the test thread, and their stack
* traces are mutated to include the async stack trace, to make them easy to track. You can disable this behavior by
* setting the {@code fabric.client.gametest.disableJoinAsyncStackTraces} system property.
*
* <p>The game remains paused unless you explicitly unpause it using various waiting functions such as
* {@link net.fabricmc.fabric.api.client.gametest.v1.ClientGameTestContext#waitTick() ClientGameTestContext.waitTick()}.
*
* <p>A few changes have been made to how the vanilla game threads run, to make tests more reproducible. Notably, there
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,11 @@ public static void start() {
for (FabricClientGameTest gameTest : gameTests) {
context.restoreDefaultGameOptions();

try {
gameTest.runTest(context);
} finally {
context.getInput().clearKeysDown();
checkFinalGameTestState(context, gameTest.getClass().getName());
}
}
gameTest.runTest(context);

context.clickScreenButton("menu.quit");
context.getInput().clearKeysDown();
checkFinalGameTestState(context, gameTest.getClass().getName());
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.minecraft.client.MinecraftClient;

/**
* <h1>Implementation notes</h1>
*
Expand Down Expand Up @@ -68,13 +70,19 @@ private ThreadingImpl() {

private static final Logger LOGGER = LoggerFactory.getLogger("fabric-client-gametest-api-v1");

private static final boolean DISABLE_JOIN_ASYNC_STACK_TRACES = System.getProperty("fabric.client.gametest.disableJoinAsyncStackTraces") != null;
private static final String THREAD_IMPL_CLASS_NAME = ThreadingImpl.class.getName();
private static final String TASK_ON_THIS_THREAD_METHOD_NAME = "runTaskOnThisThread";
private static final String TASK_ON_OTHER_THREAD_METHOD_NAME = "runTaskOnOtherThread";

public static final int PHASE_TICK = 0;
public static final int PHASE_SERVER_TASKS = 1;
public static final int PHASE_CLIENT_TASKS = 2;
public static final int PHASE_TEST = 3;
private static final int PHASE_MASK = 3;

public static final Phaser PHASER = new Phaser();
private static volatile boolean enablePhases = true;

public static volatile boolean isClientRunning = false;
public static volatile boolean clientCanAcceptTasks = false;
Expand All @@ -87,71 +95,95 @@ private ThreadingImpl() {
@Nullable
public static Thread testThread = null;
public static final Semaphore TEST_SEMAPHORE = new Semaphore(0);
@Nullable
public static Throwable testFailureException = null;

@Nullable
public static Runnable taskToRun = null;

private static volatile boolean gameCrashed = false;

public static void enterPhase(int phase) {
while ((PHASER.getPhase() & PHASE_MASK) != phase) {
while (enablePhases && (PHASER.getPhase() & PHASE_MASK) != phase) {
PHASER.arriveAndAwaitAdvance();
}

PHASER.arriveAndAwaitAdvance();
if (enablePhases) {
PHASER.arriveAndAwaitAdvance();
}
}

public static void setGameCrashed() {
enablePhases = false;
gameCrashed = true;
}

public static void runTestThread(Runnable test) {
public static void runTestThread(Runnable testRunner) {
Preconditions.checkState(testThread == null, "There is already a test thread running");

testThread = new Thread(() -> {
PHASER.register();
enterPhase(PHASE_TEST);

try {
test.run();
testRunner.run();
} catch (Throwable e) {
LOGGER.error("Failed to run client gametests", e);
System.exit(1);
testFailureException = e;
} finally {
PHASER.arriveAndDeregister();

if (clientCanAcceptTasks) {
CLIENT_SEMAPHORE.release();
runOnClient(() -> MinecraftClient.getInstance().scheduleStop());
}

if (serverCanAcceptTasks) {
SERVER_SEMAPHORE.release();
if (testFailureException != null) {
// Log this now in case the client has stopped or is otherwise unable to rethrow our exception
LOGGER.error("Client gametests failed with an exception", testFailureException);
}

testThread = null;
deregisterTestThread();
}
});
testThread.setName("Test thread");
testThread.setDaemon(true);
testThread.start();
}

private static void deregisterTestThread() {
testThread = null;
enablePhases = false;
PHASER.arriveAndDeregister();

if (clientCanAcceptTasks) {
CLIENT_SEMAPHORE.release();
}

if (serverCanAcceptTasks) {
SERVER_SEMAPHORE.release();
}
}

public static void checkOnGametestThread(String methodName) {
Preconditions.checkState(Thread.currentThread() == testThread, "%s can only be called from the client gametest thread", methodName);
}

@SuppressWarnings("unchecked")
public static <E extends Throwable> void runOnClient(FailableRunnable<E> action) throws E {
Preconditions.checkNotNull(action, "action");
checkOnGametestThread("runOnClient");
Preconditions.checkState(clientCanAcceptTasks, "runOnClient called when no client is running");
runTaskOnOtherThread(action, CLIENT_SEMAPHORE);
}

public static <E extends Throwable> void runOnServer(FailableRunnable<E> action) throws E {
Preconditions.checkNotNull(action, "action");
checkOnGametestThread("runOnServer");
Preconditions.checkState(serverCanAcceptTasks, "runOnServer called when no server is running");
runTaskOnOtherThread(action, SERVER_SEMAPHORE);
}

private static <E extends Throwable> void runTaskOnOtherThread(FailableRunnable<E> action, Semaphore clientOrServerSemaphore) throws E {
MutableObject<E> thrown = new MutableObject<>();
taskToRun = () -> {
try {
action.run();
} catch (Throwable e) {
thrown.setValue((E) e);
} finally {
taskToRun = null;
TEST_SEMAPHORE.release();
}
};
taskToRun = () -> runTaskOnThisThread(action, thrown);

CLIENT_SEMAPHORE.release();
clientOrServerSemaphore.release();

try {
TEST_SEMAPHORE.acquire();
Expand All @@ -160,39 +192,73 @@ public static <E extends Throwable> void runOnClient(FailableRunnable<E> action)
}

if (thrown.getValue() != null) {
joinAsyncStackTrace(thrown.getValue());
throw thrown.getValue();
}
}

@SuppressWarnings("unchecked")
public static <E extends Throwable> void runOnServer(FailableRunnable<E> action) throws E {
Preconditions.checkNotNull(action, "action");
checkOnGametestThread("runOnServer");
Preconditions.checkState(serverCanAcceptTasks, "runOnServer called when no server is running");
private static <E extends Throwable> void runTaskOnThisThread(FailableRunnable<E> action, MutableObject<E> thrown) {
try {
action.run();
} catch (Throwable e) {
thrown.setValue((E) e);
} finally {
taskToRun = null;
TEST_SEMAPHORE.release();
}
}

MutableObject<E> thrown = new MutableObject<>();
taskToRun = () -> {
try {
action.run();
} catch (Throwable e) {
thrown.setValue((E) e);
} finally {
taskToRun = null;
TEST_SEMAPHORE.release();
private static void joinAsyncStackTrace(Throwable e) {
if (DISABLE_JOIN_ASYNC_STACK_TRACES) {
return;
}

// find the end of the relevant part of the stack trace on the other thread
StackTraceElement[] otherThreadStackTrace = e.getStackTrace();

if (otherThreadStackTrace == null) {
return;
}

int otherThreadIndex = otherThreadStackTrace.length - 1;

for (; otherThreadIndex >= 0; otherThreadIndex--) {
StackTraceElement element = otherThreadStackTrace[otherThreadIndex];

if (THREAD_IMPL_CLASS_NAME.equals(element.getClassName()) && TASK_ON_THIS_THREAD_METHOD_NAME.equals(element.getMethodName())) {
break;
}
};
}

SERVER_SEMAPHORE.release();
if (otherThreadIndex == -1) {
// couldn't find stack trace element
return;
}

try {
TEST_SEMAPHORE.acquire();
} catch (InterruptedException e) {
throw new RuntimeException(e);
// find the start of the relevant part of the stack trace on the test thread
StackTraceElement[] thisThreadStackTrace = Thread.currentThread().getStackTrace();
int thisThreadIndex = 0;

for (; thisThreadIndex < thisThreadStackTrace.length; thisThreadIndex++) {
StackTraceElement element = thisThreadStackTrace[thisThreadIndex];

if (THREAD_IMPL_CLASS_NAME.equals(element.getClassName()) && TASK_ON_OTHER_THREAD_METHOD_NAME.equals(element.getMethodName())) {
break;
}
}

if (thrown.getValue() != null) {
throw thrown.getValue();
if (thisThreadIndex == thisThreadStackTrace.length) {
// couldn't find stack trace element
return;
}

// join the stack traces
StackTraceElement[] joinedStackTrace = new StackTraceElement[(otherThreadIndex + 1) + 1 + (thisThreadStackTrace.length - thisThreadIndex)];
System.arraycopy(otherThreadStackTrace, 0, joinedStackTrace, 0, otherThreadIndex + 1);
joinedStackTrace[otherThreadIndex + 1] = new StackTraceElement("Async Stack Trace", ".", null, 1);
System.arraycopy(thisThreadStackTrace, thisThreadIndex, joinedStackTrace, otherThreadIndex + 2, thisThreadStackTrace.length - thisThreadIndex);
e.setStackTrace(joinedStackTrace);
}

public static void runTick() {
Expand All @@ -207,5 +273,17 @@ public static void runTick() {
}

enterPhase(PHASE_TEST);

// Check if the game has crashed during this tick. If so, don't do any more work in the test
if (gameCrashed) {
deregisterTestThread();

try {
// wait until game is closed
new Semaphore(0).acquire();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class MinecraftClientMixin {
private Overlay overlay;

@WrapMethod(method = "run")
private void onRun(Operation<Void> original) {
private void onRun(Operation<Void> original) throws Throwable {
if (ThreadingImpl.isClientRunning) {
throw new IllegalStateException("Client is already running");
}
Expand All @@ -61,12 +61,21 @@ private void onRun(Operation<Void> original) {
try {
original.call();
} finally {
ThreadingImpl.clientCanAcceptTasks = false;
ThreadingImpl.PHASER.arriveAndDeregister();
ThreadingImpl.isClientRunning = false;
deregisterClient();

if (ThreadingImpl.testFailureException != null) {
throw ThreadingImpl.testFailureException;
}
}
}

@Inject(method = "cleanUpAfterCrash", at = @At("HEAD"))
private void deregisterAfterCrash(CallbackInfo ci) {
// Deregister a bit earlier than normal to allow for the integrated server to stop without waiting for the client
ThreadingImpl.setGameCrashed();
deregisterClient();
}

@Inject(method = "tick", at = @At("HEAD"))
private void onTick(CallbackInfo ci) {
if (!startedClientGametests && overlay == null) {
Expand Down Expand Up @@ -152,4 +161,13 @@ private static void checkThreadOnGetInstance(CallbackInfoReturnable<MinecraftCli
"MinecraftClient.getInstance() cannot be called from the gametest thread. Try using ClientGameTestContext.runOnClient or ClientGameTestContext.computeOnClient"
);
}

@Unique
private static void deregisterClient() {
if (ThreadingImpl.isClientRunning) {
ThreadingImpl.clientCanAcceptTasks = false;
ThreadingImpl.PHASER.arriveAndDeregister();
ThreadingImpl.isClientRunning = false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import com.llamalad7.mixinextras.injector.wrapmethod.WrapMethod;
import com.llamalad7.mixinextras.injector.wrapoperation.Operation;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;

import net.minecraft.client.MinecraftClient;
import net.minecraft.server.MinecraftServer;

import net.fabricmc.fabric.impl.client.gametest.ThreadingImpl;
Expand All @@ -41,12 +43,21 @@ private void onRunServer(Operation<Void> original) {
try {
original.call();
} finally {
ThreadingImpl.serverCanAcceptTasks = false;
ThreadingImpl.PHASER.arriveAndDeregister();
ThreadingImpl.isServerRunning = false;
deregisterServer();
}
}

@Inject(method = "runServer", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/MinecraftServer;setCrashReport(Lnet/minecraft/util/crash/CrashReport;)V", shift = At.Shift.AFTER))
protected void onCrash(CallbackInfo ci) {
if (ThreadingImpl.testFailureException == null) {
ThreadingImpl.testFailureException = new Throwable("The server crashed");
}

MinecraftClient.getInstance().scheduleStop();
ThreadingImpl.setGameCrashed();
deregisterServer();
}

@Inject(method = "runServer", at = @At(value = "INVOKE", target = "Lnet/minecraft/server/MinecraftServer;runTasksTillTickEnd()V"))
private void preRunTasks(CallbackInfo ci) {
ThreadingImpl.enterPhase(ThreadingImpl.PHASE_SERVER_TASKS);
Expand Down Expand Up @@ -78,4 +89,11 @@ private void postRunTasks(CallbackInfo ci) {

ThreadingImpl.enterPhase(ThreadingImpl.PHASE_TICK);
}

@Unique
private void deregisterServer() {
ThreadingImpl.serverCanAcceptTasks = false;
ThreadingImpl.PHASER.arriveAndDeregister();
ThreadingImpl.isServerRunning = false;
}
}
Loading