diff --git a/Harmony/Internal/CodeTranspiler.cs b/Harmony/Internal/CodeTranspiler.cs index 778e8796..d74844ea 100644 --- a/Harmony/Internal/CodeTranspiler.cs +++ b/Harmony/Internal/CodeTranspiler.cs @@ -194,7 +194,7 @@ internal static IEnumerable ConvertToGeneralInstructions(MethodInfo transpiler, { var type = transpiler.GetParameters() .Select(p => p.ParameterType) - .FirstOrDefault(t => IsCodeInstructionsParameter(t)); + .FirstOrDefault(IsCodeInstructionsParameter); if (type == typeof(IEnumerable)) { unassignedValues = null; diff --git a/Harmony/Internal/HarmonySharedState.cs b/Harmony/Internal/HarmonySharedState.cs index 170701f8..a87c6a4a 100644 --- a/Harmony/Internal/HarmonySharedState.cs +++ b/Harmony/Internal/HarmonySharedState.cs @@ -35,10 +35,13 @@ internal static class HarmonySharedState const string name = "HarmonySharedState"; internal const int internalVersion = 102; // bump this if the layout of the HarmonySharedState type changes - // state/originals/methodStarts are set to instances stored in the global dynamic types static fields with the same name + // state/originals/originalsMono are set to instances stored in the global dynamic types static fields with the same name static readonly Dictionary state; static readonly Dictionary originals; - + static readonly Dictionary originalsMono; + + static readonly AccessTools.FieldRef methodAddressRef; + internal static readonly int actualVersion; static HarmonySharedState() @@ -46,6 +49,10 @@ static HarmonySharedState() // create singleton type var type = GetOrCreateSharedStateType(); + // this field is useed to find methods from stackframes in Mono + if (AccessTools.IsMonoRuntime && AccessTools.Field(typeof(StackFrame), "methodAddress") is FieldInfo field) + methodAddressRef = AccessTools.FieldRefAccess(field); + // copy 'actualVersion' over to our fields var versionField = type.GetField("version"); if ((int)versionField.GetValue(null) == 0) @@ -62,6 +69,11 @@ static HarmonySharedState() if (originalsField != null && originalsField.GetValue(null) is null) originalsField.SetValue(null, new Dictionary()); + // get or initialize global 'originalsMono' field + var originalsMonoField = type.GetField("originalsMono"); + if (originalsMonoField != null && originalsMonoField.GetValue(null) is null) + originalsMonoField.SetValue(null, new Dictionary()); + // copy 'state' over to our fields state = (Dictionary)stateField.GetValue(null); @@ -69,6 +81,11 @@ static HarmonySharedState() originals = []; if (originalsField != null) // may not exist in older versions originals = (Dictionary)originalsField.GetValue(null); + + // copy 'originalsMono' over to our fields + originalsMono = []; + if (originalsMonoField != null) // may not exist in older versions + originalsMono = (Dictionary)originalsMonoField.GetValue(null); } // creates a dynamic 'global' type if it does not exist @@ -94,6 +111,12 @@ static Type GetOrCreateSharedStateType() module.ImportReference(typeof(Dictionary)) )); + typedef.Fields.Add(new FieldDefinition( + "originalsMono", + Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static, + module.ImportReference(typeof(Dictionary)) + )); + typedef.Fields.Add(new FieldDefinition( "version", Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static, @@ -122,19 +145,49 @@ internal static void UpdatePatchInfo(MethodBase original, MethodInfo replacement { var bytes = patchInfo.Serialize(); lock (state) state[original] = bytes; - lock (originals) originals[replacement] = original; + lock (originals) originals[replacement.Identifiable()] = original; + if (AccessTools.IsMonoRuntime) + { + var methodAddress = (long)replacement.MethodHandle.GetFunctionPointer(); + lock (originalsMono) originalsMono[methodAddress] = [original, replacement]; + } } - internal static MethodBase GetOriginal(MethodInfo replacement) + // With mono, useReplacement is used to either return the original or the replacement + // On .NET, useReplacement is ignored and the original is always returned + internal static MethodBase GetRealMethod(MethodInfo method, bool useReplacement) { - lock (originals) return originals.GetValueSafe(replacement); + var identifiableMethod = method.Identifiable(); + lock (originals) + if (originals.TryGetValue(identifiableMethod, out var original)) + return original; + + if (AccessTools.IsMonoRuntime) + { + var methodAddress = (long)method.MethodHandle.GetFunctionPointer(); + lock (originalsMono) + if (originalsMono.TryGetValue(methodAddress, out var info)) + return useReplacement ? info[1] : info[0]; + } + + return method; } - internal static MethodBase FindReplacement(StackFrame frame) + internal static MethodBase GetStackFrameMethod(StackFrame frame, bool useReplacement) { var method = frame.GetMethod() as MethodInfo; - if (method == null) return null; - return GetOriginal(method); + if (method != null) + return GetRealMethod(method, useReplacement); + + if (methodAddressRef != null) + { + var methodAddress = methodAddressRef(frame); + lock (originalsMono) + if (originalsMono.TryGetValue(methodAddress, out var info)) + return useReplacement ? info[1] : info[0]; + } + + return null; } } } diff --git a/Harmony/Internal/MethodCopier.cs b/Harmony/Internal/MethodCopier.cs index 90e603c3..e9c674a3 100644 --- a/Harmony/Internal/MethodCopier.cs +++ b/Harmony/Internal/MethodCopier.cs @@ -358,7 +358,7 @@ internal List FinalizeILCodes(Emitter emitter, List // pass2 - filter through all processors // var codeTranspiler = new CodeTranspiler(ilInstructions); - transpilers.Do(transpiler => codeTranspiler.Add(transpiler)); + transpilers.Do(codeTranspiler.Add); var codeInstructions = codeTranspiler.GetResult(generator, method); if (emitter is null) diff --git a/Harmony/Internal/MethodPatcher.cs b/Harmony/Internal/MethodPatcher.cs index 3c19924d..4b5356ec 100644 --- a/Harmony/Internal/MethodPatcher.cs +++ b/Harmony/Internal/MethodPatcher.cs @@ -101,7 +101,7 @@ internal MethodInfo CreateReplacement(out Dictionary final Label? skipOriginalLabel = null; LocalBuilder runOriginalVariable = null; - var prefixAffectsOriginal = prefixes.Any(fix => PrefixAffectsOriginal(fix)); + var prefixAffectsOriginal = prefixes.Any(PrefixAffectsOriginal); var anyFixHasRunOriginalVar = fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RUN_ORIGINAL_VAR)); if (prefixAffectsOriginal || anyFixHasRunOriginalVar) { diff --git a/Harmony/Internal/PatchModels.cs b/Harmony/Internal/PatchModels.cs index ca97c70c..dadf9985 100644 --- a/Harmony/Internal/PatchModels.cs +++ b/Harmony/Internal/PatchModels.cs @@ -101,7 +101,7 @@ internal static AttributePatch Create(MethodInfo patch) var f_info = AccessTools.Field(attr.GetType(), nameof(HarmonyAttribute.info)); return f_info.GetValue(attr); }) - .Select(harmonyInfo => AccessTools.MakeDeepCopy(harmonyInfo)) + .Select(AccessTools.MakeDeepCopy) .ToList(); var info = HarmonyMethod.Merge(list); info.method = patch; diff --git a/Harmony/Internal/PatchTools.cs b/Harmony/Internal/PatchTools.cs index 2f5d92d7..82c55062 100644 --- a/Harmony/Internal/PatchTools.cs +++ b/Harmony/Internal/PatchTools.cs @@ -32,7 +32,7 @@ internal static void DetourMethod(MethodBase method, MethodBase replacement) static Assembly GetExecutingAssemblyReplacement() { var frames = new StackTrace().GetFrames(); - if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetOriginalMethodFromStackframe(frame) is { } original) + if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetMethodFromStackframe(frame) is { } original) return original.Module.Assembly; return Assembly.GetExecutingAssembly(); } @@ -78,7 +78,7 @@ internal static AssemblyBuilder DefineDynamicAssembly(string name) internal static List GetPatchMethods(Type type) { return AccessTools.GetDeclaredMethods(type) - .Select(method => AttributePatch.Create(method)) + .Select(AttributePatch.Create) .Where(attributePatch => attributePatch is not null) .ToList(); } diff --git a/Harmony/Public/Harmony.cs b/Harmony/Public/Harmony.cs index 65b8e9ed..d87909d7 100644 --- a/Harmony/Public/Harmony.cs +++ b/Harmony/Public/Harmony.cs @@ -1,4 +1,3 @@ -using MonoMod.Core.Platforms; using System; using System.Collections.Generic; using System.Diagnostics; @@ -227,7 +226,7 @@ public void Unpatch(MethodBase original, MethodInfo patch) public static bool HasAnyPatches(string harmonyID) { return GetAllPatchedMethods() - .Select(original => GetPatchInfo(original)) + .Select(GetPatchInfo) .Any(info => info.Owners.Contains(harmonyID)); } @@ -252,15 +251,13 @@ public IEnumerable GetPatchedMethods() public static IEnumerable GetAllPatchedMethods() => PatchProcessor.GetAllPatchedMethods(); /// Gets the original method from a given replacement method - /// A replacement method, for example from a stacktrace + /// A replacement method (patched original method) /// The original method/constructor or null if not found /// public static MethodBase GetOriginalMethod(MethodInfo replacement) { if (replacement == null) throw new ArgumentNullException(nameof(replacement)); - // The runtime can return several different MethodInfo's that point to the same method. Use the correct one - var identifiableReplacement = PlatformTriple.Current.GetIdentifiable(replacement) as MethodInfo; - return HarmonySharedState.GetOriginal(identifiableReplacement); + return HarmonySharedState.GetRealMethod(replacement, useReplacement: false); } /// Tries to get the method from a stackframe including dynamic replacement methods @@ -270,7 +267,7 @@ public static MethodBase GetOriginalMethod(MethodInfo replacement) public static MethodBase GetMethodFromStackframe(StackFrame frame) { if (frame == null) throw new ArgumentNullException(nameof(frame)); - return HarmonySharedState.FindReplacement(frame) ?? frame.GetMethod(); + return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: true); } /// Gets the original method from the stackframe and uses original if method is a dynamic replacement @@ -278,16 +275,15 @@ public static MethodBase GetMethodFromStackframe(StackFrame frame) /// The original method from that stackframe public static MethodBase GetOriginalMethodFromStackframe(StackFrame frame) { - var member = GetMethodFromStackframe(frame); - if (member is MethodInfo methodInfo) - member = GetOriginalMethod(methodInfo) ?? member; - return member; + if (frame == null) throw new ArgumentNullException(nameof(frame)); + return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: false); } /// Gets Harmony version for all active Harmony instances /// [out] The current Harmony version /// A dictionary containing assembly versions keyed by Harmony IDs /// - public static Dictionary VersionInfo(out Version currentVersion) => PatchProcessor.VersionInfo(out currentVersion); + public static Dictionary VersionInfo(out Version currentVersion) + => PatchProcessor.VersionInfo(out currentVersion); } } diff --git a/Harmony/Public/HarmonyMethod.cs b/Harmony/Public/HarmonyMethod.cs index e19d8cc0..379ae6db 100644 --- a/Harmony/Public/HarmonyMethod.cs +++ b/Harmony/Public/HarmonyMethod.cs @@ -292,7 +292,7 @@ static HarmonyMethod GetHarmonyMethodInfo(object attribute) public static List GetFromType(Type type) { return type.GetCustomAttributes(true) - .Select(attr => GetHarmonyMethodInfo(attr)) + .Select(GetHarmonyMethodInfo) .Where(info => info is not null) .ToList(); } @@ -310,7 +310,7 @@ public static List GetFromType(Type type) public static List GetFromMethod(MethodBase method) { return method.GetCustomAttributes(true) - .Select(attr => GetHarmonyMethodInfo(attr)) + .Select(GetHarmonyMethodInfo) .Where(info => info is not null) .ToList(); } diff --git a/Harmony/Tools/AccessTools.cs b/Harmony/Tools/AccessTools.cs index fe73ec53..cac03790 100644 --- a/Harmony/Tools/AccessTools.cs +++ b/Harmony/Tools/AccessTools.cs @@ -1,3 +1,4 @@ +using MonoMod.Core.Platforms; using MonoMod.Utils; using System; using System.Collections; @@ -84,7 +85,7 @@ public static Type[] GetTypesFromAssembly(Assembly assembly) /// Enumerates all successfully loaded types in the current app domain, excluding visual studio assemblies /// An enumeration of all in all assemblies, excluding visual studio assemblies /// - public static IEnumerable AllTypes() => AllAssemblies().SelectMany(a => GetTypesFromAssembly(a)); + public static IEnumerable AllTypes() => AllAssemblies().SelectMany(GetTypesFromAssembly); /// Enumerates all inner types (non-recursive) of a given type /// The class/type to start with @@ -133,6 +134,11 @@ public static T FindIncludingInnerTypes(Type type, Func func) where return result; } + /// Creates an identifiable version of a method + /// The method + /// + public static MethodInfo Identifiable(this MethodInfo method) => PlatformTriple.Current.GetIdentifiable(method) as MethodInfo ?? method; + /// Gets the reflection information for a directly declared field /// The class/type where the field is defined /// The name of the field diff --git a/HarmonyTests/Extras/RetrieveOriginalMethod.cs b/HarmonyTests/Extras/RetrieveOriginalMethod.cs index e2769269..5554b4a3 100644 --- a/HarmonyTests/Extras/RetrieveOriginalMethod.cs +++ b/HarmonyTests/Extras/RetrieveOriginalMethod.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; using System.Reflection; +using System.Runtime.CompilerServices; namespace HarmonyLibTests.Extras { @@ -14,20 +15,20 @@ private static void CheckStackTraceFor(MethodBase expectedMethod) Assert.NotNull(expectedMethod); var st = new StackTrace(1, false); - var method = Harmony.GetMethodFromStackframe(st.GetFrame(0)); - - Assert.NotNull(method); - - if (method is MethodInfo replacement) - { - var original = Harmony.GetOriginalMethod(replacement); - Assert.NotNull(original); - Assert.AreEqual(original, expectedMethod); - } + var frame = st.GetFrame(0); + Assert.NotNull(frame); + + var methodFromStackframe = Harmony.GetMethodFromStackframe(frame); + Assert.NotNull(methodFromStackframe); + Assert.AreEqual(expectedMethod, methodFromStackframe); + + var replacement = frame.GetMethod() as MethodInfo; + Assert.NotNull(replacement); + var original = Harmony.GetOriginalMethod(replacement); + Assert.NotNull(original); + Assert.AreEqual(expectedMethod, original); } - /* TODO - * [Test] public void TestRegularMethod() { @@ -37,7 +38,7 @@ public void TestRegularMethod() _ = harmony.Patch(originalMethod, new HarmonyMethod(dummyPrefix)); PatchTarget(); } - + [Test] public void TestConstructor() { @@ -48,7 +49,6 @@ public void TestConstructor() var inst = new NestedClass(5); _ = inst.index; } - */ internal static void PatchTarget() { @@ -60,7 +60,7 @@ internal static void PatchTarget() } } - // [MethodImpl(MethodImplOptions.NoInlining)] + [MethodImpl(MethodImplOptions.NoInlining)] internal static void DummyPrefix() { } @@ -69,7 +69,7 @@ class NestedClass { public NestedClass(int i) { try { - CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)])); + CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)])); throw new Exception(); } catch (Exception e) {