Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/get method from stackframe #601

Merged
merged 4 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Harmony/Internal/CodeTranspiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ internal static IEnumerable ConvertToGeneralInstructions(MethodInfo transpiler,
{
var type = transpiler.GetParameters()
.Select(p => p.ParameterType)
.FirstOrDefault(t => IsCodeInstructionsParameter(t));
.FirstOrDefault(IsCodeInstructionsParameter);
if (type == typeof(IEnumerable<CodeInstruction>))
{
unassignedValues = null;
Expand Down
69 changes: 61 additions & 8 deletions Harmony/Internal/HarmonySharedState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,24 @@ internal static class HarmonySharedState
const string name = "HarmonySharedState";
internal const int internalVersion = 102; // bump this if the layout of the HarmonySharedState type changes

// state/originals/methodStarts are set to instances stored in the global dynamic types static fields with the same name
// state/originals/originalsMono are set to instances stored in the global dynamic types static fields with the same name
static readonly Dictionary<MethodBase, byte[]> state;
static readonly Dictionary<MethodInfo, MethodBase> originals;

static readonly Dictionary<long, MethodBase[]> originalsMono;

static readonly AccessTools.FieldRef<StackFrame, long> methodAddressRef;

internal static readonly int actualVersion;

static HarmonySharedState()
{
// create singleton type
var type = GetOrCreateSharedStateType();

// this field is useed to find methods from stackframes in Mono
if (AccessTools.IsMonoRuntime && AccessTools.Field(typeof(StackFrame), "methodAddress") is FieldInfo field)
methodAddressRef = AccessTools.FieldRefAccess<StackFrame, long>(field);

// copy 'actualVersion' over to our fields
var versionField = type.GetField("version");
if ((int)versionField.GetValue(null) == 0)
Expand All @@ -62,13 +69,23 @@ static HarmonySharedState()
if (originalsField != null && originalsField.GetValue(null) is null)
originalsField.SetValue(null, new Dictionary<MethodInfo, MethodBase>());

// get or initialize global 'originalsMono' field
var originalsMonoField = type.GetField("originalsMono");
if (originalsMonoField != null && originalsMonoField.GetValue(null) is null)
originalsMonoField.SetValue(null, new Dictionary<long, MethodBase[]>());

// copy 'state' over to our fields
state = (Dictionary<MethodBase, byte[]>)stateField.GetValue(null);

// copy 'originals' over to our fields
originals = [];
if (originalsField != null) // may not exist in older versions
originals = (Dictionary<MethodInfo, MethodBase>)originalsField.GetValue(null);

// copy 'originalsMono' over to our fields
originalsMono = [];
if (originalsMonoField != null) // may not exist in older versions
originalsMono = (Dictionary<long, MethodBase[]>)originalsMonoField.GetValue(null);
}

// creates a dynamic 'global' type if it does not exist
Expand All @@ -94,6 +111,12 @@ static Type GetOrCreateSharedStateType()
module.ImportReference(typeof(Dictionary<MethodInfo, MethodBase>))
));

typedef.Fields.Add(new FieldDefinition(
"originalsMono",
Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static,
module.ImportReference(typeof(Dictionary<long, MethodBase[]>))
));

typedef.Fields.Add(new FieldDefinition(
"version",
Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static,
Expand Down Expand Up @@ -122,19 +145,49 @@ internal static void UpdatePatchInfo(MethodBase original, MethodInfo replacement
{
var bytes = patchInfo.Serialize();
lock (state) state[original] = bytes;
lock (originals) originals[replacement] = original;
lock (originals) originals[replacement.Identifiable()] = original;
if (AccessTools.IsMonoRuntime)
{
var methodAddress = (long)replacement.MethodHandle.GetFunctionPointer();
lock (originalsMono) originalsMono[methodAddress] = [original, replacement];
}
}

internal static MethodBase GetOriginal(MethodInfo replacement)
// With mono, useReplacement is used to either return the original or the replacement
// On .NET, useReplacement is ignored and the original is always returned
internal static MethodBase GetRealMethod(MethodInfo method, bool useReplacement)
{
lock (originals) return originals.GetValueSafe(replacement);
var identifiableMethod = method.Identifiable();
lock (originals)
if (originals.TryGetValue(identifiableMethod, out var original))
return original;

if (AccessTools.IsMonoRuntime)
{
var methodAddress = (long)method.MethodHandle.GetFunctionPointer();
lock (originalsMono)
if (originalsMono.TryGetValue(methodAddress, out var info))
return useReplacement ? info[1] : info[0];
}

return method;
}

internal static MethodBase FindReplacement(StackFrame frame)
internal static MethodBase GetStackFrameMethod(StackFrame frame, bool useReplacement)
{
var method = frame.GetMethod() as MethodInfo;
if (method == null) return null;
return GetOriginal(method);
if (method != null)
return GetRealMethod(method, useReplacement);

if (methodAddressRef != null)
{
var methodAddress = methodAddressRef(frame);
lock (originalsMono)
if (originalsMono.TryGetValue(methodAddress, out var info))
return useReplacement ? info[1] : info[0];
}

return null;
}
}
}
2 changes: 1 addition & 1 deletion Harmony/Internal/MethodCopier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ internal List<CodeInstruction> FinalizeILCodes(Emitter emitter, List<MethodInfo>
// pass2 - filter through all processors
//
var codeTranspiler = new CodeTranspiler(ilInstructions);
transpilers.Do(transpiler => codeTranspiler.Add(transpiler));
transpilers.Do(codeTranspiler.Add);
var codeInstructions = codeTranspiler.GetResult(generator, method);

if (emitter is null)
Expand Down
2 changes: 1 addition & 1 deletion Harmony/Internal/MethodPatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ internal MethodInfo CreateReplacement(out Dictionary<int, CodeInstruction> final

Label? skipOriginalLabel = null;
LocalBuilder runOriginalVariable = null;
var prefixAffectsOriginal = prefixes.Any(fix => PrefixAffectsOriginal(fix));
var prefixAffectsOriginal = prefixes.Any(PrefixAffectsOriginal);
var anyFixHasRunOriginalVar = fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RUN_ORIGINAL_VAR));
if (prefixAffectsOriginal || anyFixHasRunOriginalVar)
{
Expand Down
2 changes: 1 addition & 1 deletion Harmony/Internal/PatchModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ internal static AttributePatch Create(MethodInfo patch)
var f_info = AccessTools.Field(attr.GetType(), nameof(HarmonyAttribute.info));
return f_info.GetValue(attr);
})
.Select(harmonyInfo => AccessTools.MakeDeepCopy<HarmonyMethod>(harmonyInfo))
.Select(AccessTools.MakeDeepCopy<HarmonyMethod>)
.ToList();
var info = HarmonyMethod.Merge(list);
info.method = patch;
Expand Down
4 changes: 2 additions & 2 deletions Harmony/Internal/PatchTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal static void DetourMethod(MethodBase method, MethodBase replacement)
static Assembly GetExecutingAssemblyReplacement()
{
var frames = new StackTrace().GetFrames();
if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetOriginalMethodFromStackframe(frame) is { } original)
if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetMethodFromStackframe(frame) is { } original)
return original.Module.Assembly;
return Assembly.GetExecutingAssembly();
}
Expand Down Expand Up @@ -78,7 +78,7 @@ internal static AssemblyBuilder DefineDynamicAssembly(string name)
internal static List<AttributePatch> GetPatchMethods(Type type)
{
return AccessTools.GetDeclaredMethods(type)
.Select(method => AttributePatch.Create(method))
.Select(AttributePatch.Create)
.Where(attributePatch => attributePatch is not null)
.ToList();
}
Expand Down
20 changes: 8 additions & 12 deletions Harmony/Public/Harmony.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using MonoMod.Core.Platforms;
using System;
using System.Collections.Generic;
using System.Diagnostics;
Expand Down Expand Up @@ -227,7 +226,7 @@ public void Unpatch(MethodBase original, MethodInfo patch)
public static bool HasAnyPatches(string harmonyID)
{
return GetAllPatchedMethods()
.Select(original => GetPatchInfo(original))
.Select(GetPatchInfo)
.Any(info => info.Owners.Contains(harmonyID));
}

Expand All @@ -252,15 +251,13 @@ public IEnumerable<MethodBase> GetPatchedMethods()
public static IEnumerable<MethodBase> GetAllPatchedMethods() => PatchProcessor.GetAllPatchedMethods();

/// <summary>Gets the original method from a given replacement method</summary>
/// <param name="replacement">A replacement method, for example from a stacktrace</param>
/// <param name="replacement">A replacement method (patched original method)</param>
/// <returns>The original method/constructor or <c>null</c> if not found</returns>
///
public static MethodBase GetOriginalMethod(MethodInfo replacement)
{
if (replacement == null) throw new ArgumentNullException(nameof(replacement));
// The runtime can return several different MethodInfo's that point to the same method. Use the correct one
var identifiableReplacement = PlatformTriple.Current.GetIdentifiable(replacement) as MethodInfo;
return HarmonySharedState.GetOriginal(identifiableReplacement);
return HarmonySharedState.GetRealMethod(replacement, useReplacement: false);
}

/// <summary>Tries to get the method from a stackframe including dynamic replacement methods</summary>
Expand All @@ -270,24 +267,23 @@ public static MethodBase GetOriginalMethod(MethodInfo replacement)
public static MethodBase GetMethodFromStackframe(StackFrame frame)
{
if (frame == null) throw new ArgumentNullException(nameof(frame));
return HarmonySharedState.FindReplacement(frame) ?? frame.GetMethod();
return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: true);
}

/// <summary>Gets the original method from the stackframe and uses original if method is a dynamic replacement</summary>
/// <param name="frame">The <see cref="StackFrame"/></param>
/// <returns>The original method from that stackframe</returns>
public static MethodBase GetOriginalMethodFromStackframe(StackFrame frame)
{
var member = GetMethodFromStackframe(frame);
if (member is MethodInfo methodInfo)
member = GetOriginalMethod(methodInfo) ?? member;
return member;
if (frame == null) throw new ArgumentNullException(nameof(frame));
return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: false);
}

/// <summary>Gets Harmony version for all active Harmony instances</summary>
/// <param name="currentVersion">[out] The current Harmony version</param>
/// <returns>A dictionary containing assembly versions keyed by Harmony IDs</returns>
///
public static Dictionary<string, Version> VersionInfo(out Version currentVersion) => PatchProcessor.VersionInfo(out currentVersion);
public static Dictionary<string, Version> VersionInfo(out Version currentVersion)
=> PatchProcessor.VersionInfo(out currentVersion);
}
}
4 changes: 2 additions & 2 deletions Harmony/Public/HarmonyMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ static HarmonyMethod GetHarmonyMethodInfo(object attribute)
public static List<HarmonyMethod> GetFromType(Type type)
{
return type.GetCustomAttributes(true)
.Select(attr => GetHarmonyMethodInfo(attr))
.Select(GetHarmonyMethodInfo)
.Where(info => info is not null)
.ToList();
}
Expand All @@ -310,7 +310,7 @@ public static List<HarmonyMethod> GetFromType(Type type)
public static List<HarmonyMethod> GetFromMethod(MethodBase method)
{
return method.GetCustomAttributes(true)
.Select(attr => GetHarmonyMethodInfo(attr))
.Select(GetHarmonyMethodInfo)
.Where(info => info is not null)
.ToList();
}
Expand Down
8 changes: 7 additions & 1 deletion Harmony/Tools/AccessTools.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using MonoMod.Core.Platforms;
using MonoMod.Utils;
using System;
using System.Collections;
Expand Down Expand Up @@ -84,7 +85,7 @@ public static Type[] GetTypesFromAssembly(Assembly assembly)
/// <summary>Enumerates all successfully loaded types in the current app domain, excluding visual studio assemblies</summary>
/// <returns>An enumeration of all <see cref="Type"/> in all assemblies, excluding visual studio assemblies</returns>
///
public static IEnumerable<Type> AllTypes() => AllAssemblies().SelectMany(a => GetTypesFromAssembly(a));
public static IEnumerable<Type> AllTypes() => AllAssemblies().SelectMany(GetTypesFromAssembly);

/// <summary>Enumerates all inner types (non-recursive) of a given type</summary>
/// <param name="type">The class/type to start with</param>
Expand Down Expand Up @@ -133,6 +134,11 @@ public static T FindIncludingInnerTypes<T>(Type type, Func<Type, T> func) where
return result;
}

/// <summary>Creates an identifiable version of a method</summary>
/// <param name="method">The method</param>
/// <returns></returns>
public static MethodInfo Identifiable(this MethodInfo method) => PlatformTriple.Current.GetIdentifiable(method) as MethodInfo ?? method;

/// <summary>Gets the reflection information for a directly declared field</summary>
/// <param name="type">The class/type where the field is defined</param>
/// <param name="name">The name of the field</param>
Expand Down
32 changes: 16 additions & 16 deletions HarmonyTests/Extras/RetrieveOriginalMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Diagnostics;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace HarmonyLibTests.Extras
{
Expand All @@ -14,20 +15,20 @@ private static void CheckStackTraceFor(MethodBase expectedMethod)
Assert.NotNull(expectedMethod);

var st = new StackTrace(1, false);
var method = Harmony.GetMethodFromStackframe(st.GetFrame(0));

Assert.NotNull(method);

if (method is MethodInfo replacement)
{
var original = Harmony.GetOriginalMethod(replacement);
Assert.NotNull(original);
Assert.AreEqual(original, expectedMethod);
}
var frame = st.GetFrame(0);
Assert.NotNull(frame);

var methodFromStackframe = Harmony.GetMethodFromStackframe(frame);
Assert.NotNull(methodFromStackframe);
Assert.AreEqual(expectedMethod, methodFromStackframe);

var replacement = frame.GetMethod() as MethodInfo;
Assert.NotNull(replacement);
var original = Harmony.GetOriginalMethod(replacement);
Assert.NotNull(original);
Assert.AreEqual(expectedMethod, original);
}

/* TODO
*
[Test]
public void TestRegularMethod()
{
Expand All @@ -37,7 +38,7 @@ public void TestRegularMethod()
_ = harmony.Patch(originalMethod, new HarmonyMethod(dummyPrefix));
PatchTarget();
}

[Test]
public void TestConstructor()
{
Expand All @@ -48,7 +49,6 @@ public void TestConstructor()
var inst = new NestedClass(5);
_ = inst.index;
}
*/

internal static void PatchTarget()
{
Expand All @@ -60,7 +60,7 @@ internal static void PatchTarget()
}
}

// [MethodImpl(MethodImplOptions.NoInlining)]
[MethodImpl(MethodImplOptions.NoInlining)]
internal static void DummyPrefix()
{
}
Expand All @@ -69,7 +69,7 @@ class NestedClass {
public NestedClass(int i)
{
try {
CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)]));
CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)]));
throw new Exception();
} catch (Exception e)
{
Expand Down
Loading