From 54349c25f1f7c755e4c67cd5ff25c22664683dcd Mon Sep 17 00:00:00 2001 From: psyGamer Date: Sun, 15 Dec 2024 14:50:44 +0100 Subject: [PATCH] feat: Add EmitStaticDelegate helper method --- .../Source/EverestInterop/EntityDataHelper.cs | 40 ++--- .../Source/Utils/HookHelper.cs | 142 +++++++++++++++++- 2 files changed, 160 insertions(+), 22 deletions(-) diff --git a/CelesteTAS-EverestInterop/Source/EverestInterop/EntityDataHelper.cs b/CelesteTAS-EverestInterop/Source/EverestInterop/EntityDataHelper.cs index 5c297220..3d280b83 100644 --- a/CelesteTAS-EverestInterop/Source/EverestInterop/EntityDataHelper.cs +++ b/CelesteTAS-EverestInterop/Source/EverestInterop/EntityDataHelper.cs @@ -227,31 +227,31 @@ private static void StrawberrySeedOnCtor(On.Celeste.StrawberrySeed.orig_ctor ori private static void ModSpawnEntity(ILContext il) { ILCursor cursor = new(il); - if (cursor.TryGotoNext( - i => i.OpCode == OpCodes.Callvirt && i.Operand.ToString() == "System.Void Monocle.Scene::Add(Monocle.Entity)")) { - cursor.Emit(OpCodes.Dup).Emit(OpCodes.Ldarg_0); + if (cursor.TryGotoNext(ins => ins.OpCode == OpCodes.Callvirt && ins.Operand.ToString() == "System.Void Monocle.Scene::Add(Monocle.Entity)")) { + cursor.EmitDup(); + cursor.EmitLdarg0(); + + // TODO: Better match if (il.ToString().Contains("ldfld Celeste.SeekerStatue Celeste.SeekerStatue/<>c__DisplayClass3_0::<>4__this") && ModUtils.VanillaAssembly.GetType("Celeste.SeekerStatue+<>c__DisplayClass3_0")?.GetFieldInfo("<>4__this") is { } seekerStatue ) { - cursor.Emit(OpCodes.Ldfld, seekerStatue); - } - - cursor.EmitDelegate>(SetCustomEntityData); - } - } - - private static void SetCustomEntityData(Entity spawnedEntity, Entity entity) { - if (entity.GetEntityData() is { } entityData) { - EntityData clonedEntityData = entityData.ShallowClone(); - if (spawnedEntity is FireBall fireBall) { - clonedEntityData.ID = clonedEntityData.ID * -100 - fireBall.index; - } else if (entity is CS03_OshiroRooftop) { - clonedEntityData.ID = 2; - } else { - clonedEntityData.ID *= -1; + cursor.EmitLdfld(seekerStatue); } - spawnedEntity.SetEntityData(clonedEntityData); + cursor.EmitStaticDelegate("SetCustomEntityData", static (Entity spawnedEntity, Entity entity) => { + if (entity.GetEntityData() is { } entityData) { + EntityData clonedEntityData = entityData.ShallowClone(); + if (spawnedEntity is FireBall fireBall) { + clonedEntityData.ID = clonedEntityData.ID * -100 - fireBall.index; + } else if (entity is CS03_OshiroRooftop) { + clonedEntityData.ID = 2; + } else { + clonedEntityData.ID *= -1; + } + + spawnedEntity.SetEntityData(clonedEntityData); + } + }); } } diff --git a/CelesteTAS-EverestInterop/Source/Utils/HookHelper.cs b/CelesteTAS-EverestInterop/Source/Utils/HookHelper.cs index f3b7881f..f617fbe5 100644 --- a/CelesteTAS-EverestInterop/Source/Utils/HookHelper.cs +++ b/CelesteTAS-EverestInterop/Source/Utils/HookHelper.cs @@ -1,12 +1,18 @@ -using System; +using Celeste; +using Celeste.Mod; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Numerics; using System.Reflection; using JetBrains.Annotations; +using Mono.Cecil; using Mono.Cecil.Cil; using MonoMod.Cil; using MonoMod.RuntimeDetour; +using MonoMod.Utils; +using System.Linq; +using System.Runtime.Loader; using TAS.Module; namespace TAS.Utils; @@ -221,4 +227,136 @@ public static void ReturnZeroMethod(Type conditionType, string conditionMethodNa } } } -} \ No newline at end of file + + /// Emits a call to a static delegate function. + /// Accessing captures is not allowed + public static void EmitStaticDelegate(this ILCursor cursor, T cb) where T : Delegate + => cursor.EmitStaticDelegate("Delegate", cb); + + /// Emits a call to a static delegate function. + /// Accessing captures is not allowed + public static void EmitStaticDelegate(this ILCursor cursor, string methodName, T cb) where T : Delegate { + // Simple static method group + if (cb.GetInvocationList().Length == 1 && cb.Target == null) { + cursor.EmitCall(cb.Method); + return; + } + + var methodDef = cb.Method.ResolveDefinition(); + + // Extract hook name from delegate + string hookName = cb.Method.Name.Split('>')[0][1..]; + string name = $"{hookName}_{methodName}"; + + var parameters = cb.Method.GetParameters(); + + var dynamicMethod = new DynamicMethodDefinition(name, + cb.Method.ReturnType, + parameters + .Select(p => p.ParameterType) + .ToArray()); + dynamicMethod.Definition.Body = methodDef.Body; + for (int i = 0; i < dynamicMethod.Definition.Parameters.Count; i++) { + dynamicMethod.Definition.Parameters[i].Name = parameters[i].Name; + } + + // Shift over arguments, since "this" was removed + var processor = dynamicMethod.GetILProcessor(); + foreach (var instr in processor.Body.Instructions) { + if (!instr.MatchLdarg(out int index)) { + continue; + } + + switch (index) { + case 0: + throw new Exception("Using captured variables inside a static delegate is not allowed"); + + case 1: + instr.OpCode = OpCodes.Ldarg_0; + break; + case 2: + instr.OpCode = OpCodes.Ldarg_1; + break; + case 3: + instr.OpCode = OpCodes.Ldarg_2; + break; + case 4: + instr.OpCode = OpCodes.Ldarg_3; + break; + + default: + instr.OpCode = OpCodes.Ldarg; + instr.Operand = index - 1; + break; + } + } + + var targetMethod = dynamicMethod.Generate(); + var targetReference = cursor.Context.Import(targetMethod); + targetReference.Name = name; + targetReference.DeclaringType = cb.Method.DeclaringType?.DeclaringType.ResolveDefinition(); + targetReference.ReturnType = dynamicMethod.Definition.ReturnType; + targetReference.Parameters.AddRange(dynamicMethod.Definition.Parameters); + + cursor.EmitCall(targetReference); + } + + /// Resolves the TypeDefinition of a runtime TypeInfo + public static TypeDefinition ResolveDefinition(this Type type) { + var asm = type.Assembly; + var asmName = type.Assembly.GetName(); + + // Find assembly path + string asmPath; + if (AssemblyLoadContext.GetLoadContext(asm) is EverestModuleAssemblyContext asmCtx) { + asmPath = Everest.Relinker.GetCachedPath(asmCtx.ModuleMeta, asmName.Name); + } else { + asmPath = asm.Location; + } + + var asmDef = AssemblyDefinition.ReadAssembly(asmPath, new ReaderParameters { ReadSymbols = false }); + var typeDef = asmDef.MainModule.GetType(type.FullName, runtimeName: true).Resolve(); + + return typeDef; + } + + /// Resolves the MethodDefinition of a runtime MethodBase + public static MethodDefinition ResolveDefinition(this MethodBase method) { + var asm = method.DeclaringType!.Assembly; + var asmName = method.DeclaringType!.Assembly.GetName(); + + // Find assembly path + string asmPath; + if (AssemblyLoadContext.GetLoadContext(asm) is EverestModuleAssemblyContext asmCtx) { + asmPath = Everest.Relinker.GetCachedPath(asmCtx.ModuleMeta, asmName.Name); + } else { + asmPath = asm.Location; + } + + var asmDef = AssemblyDefinition.ReadAssembly(asmPath, new ReaderParameters { ReadSymbols = false }); + var typeDef = asmDef.MainModule.GetType(method.DeclaringType!.FullName, runtimeName: true).Resolve(); + var methodDef = typeDef.Methods.Single(m => { + if (method.Name != m.Name) { + return false; + } + + var runtimeParams = method.GetParameters(); + if (runtimeParams.Length != m.Parameters.Count) { + return false; + } + + for (int i = 0; i < runtimeParams.Length; i++) { + var runtimeParam = runtimeParams[i]; + var asmParam = m.Parameters[i]; + + if (runtimeParam.ParameterType.FullName != asmParam.ParameterType.FullName) { + return false; + } + } + + return true; + }); + + return methodDef; + } +}