Skip to content

Commit

Permalink
- Fixed NRE in TryBeginDispose
Browse files Browse the repository at this point in the history
- Made `OnException` event useful.
- Added some null checks where expected.
- Fixed overridden Unload not being called.
- Removed partial from AssemblyManager.cs
- Made ClearTypesList() actually work.
- Made exception details show in console on release builds.
- Made content package name show on plugin load.
- Made execution standard instead of none for autogenerated and erroneous RunConfigs.
  • Loading branch information
TBN-MapleWheels committed Oct 22, 2023
1 parent ac068aa commit 91e2ab8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace Barotrauma;
/// Provides functionality for the loading, unloading and management of plugins implementing IAssemblyPlugin.
/// All plugins are loaded into their own AssemblyLoadContext along with their dependencies.
/// </summary>
public partial class AssemblyManager
public class AssemblyManager
{
#region ExternalAPI

Expand Down Expand Up @@ -143,16 +143,23 @@ public IEnumerable<Type> GetSubTypesInLoadedAssemblies<T>(bool rebuildList)
{
if (!_subTypesLookupCache.TryAdd(typeName, list1))
{
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to add subtypes to cache of type {typeName}!");
ModUtils.Logging.PrintError(
$"{nameof(AssemblyManager)}: Unable to add subtypes to cache of type {typeName}!");
}
}
else
{
ModUtils.Logging.PrintMessage($"{nameof(AssemblyManager)}: Warning: No types found during search for subtypes of {typeName}");
ModUtils.Logging.PrintMessage(
$"{nameof(AssemblyManager)}: Warning: No types found during search for subtypes of {typeName}");
}

return list1;
}
catch (Exception e)
{
this.OnException?.Invoke($"{nameof(AssemblyManager)}::{nameof(GetSubTypesInLoadedAssemblies)}() | Error: {e.Message}", e);
return ImmutableList<Type>.Empty;
}
finally
{
OpsLockLoaded.ExitReadLock();
Expand Down Expand Up @@ -187,7 +194,6 @@ public bool TryGetSubTypesFromACL<T>(Guid id, out IEnumerable<Type> types)
/// </summary>
/// <param name="id"></param>
/// <param name="types"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public bool TryGetSubTypesFromACL(Guid id, out IEnumerable<Type> types)
{
Expand All @@ -206,7 +212,7 @@ public bool TryGetSubTypesFromACL(Guid id, out IEnumerable<Type> types)
/// Allows iteration over all types, including interfaces, in all loaded assemblies in the AsmMgr who's names match the string.
/// Note: Will return the by-reference equivalent type if the type name is prefixed with "out " or "ref ".
/// </summary>
/// <param name="name">The string name of the type to search for.</param>
/// <param name="typeName">The string name of the type to search for.</param>
/// <returns>An Enumerator for matching types. List will be empty if bad params are supplied.</returns>
public IEnumerable<Type> GetTypesByName(string typeName)
{
Expand Down Expand Up @@ -243,12 +249,19 @@ public IEnumerable<Type> GetTypesByName(string typeName)
types.Add(byRef ? t.MakeByRefType() : t);
return types;
}

foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
{
t = assembly.GetType(typeName, false, false);
if (t is not null)
types.Add(byRef ? t.MakeByRefType() : t);
try
{
t = assembly.GetType(typeName, false, false);
if (t is not null)
types.Add(byRef ? t.MakeByRefType() : t);
}
catch (Exception e)
{
this.OnException?.Invoke($"{nameof(AssemblyManager)}::{nameof(GetTypesByName)}() | Error: {e.Message}", e);
}
}
}
finally
Expand Down Expand Up @@ -316,9 +329,9 @@ public IEnumerable<Type> GetAllTypesInLoadedAssemblies()
/// <returns></returns>
public IEnumerable<LoadedACL> GetAllLoadedACLs()
{
OpsLockLoaded.EnterReadLock();
try
{
OpsLockLoaded.EnterReadLock();
return LoadedACLs.Select(kvp => kvp.Value).ToImmutableList();
}
finally
Expand Down Expand Up @@ -360,6 +373,9 @@ public AssemblyLoadingSuccessState LoadAssemblyFromMemory([NotNull] string compi
// validation
if (compiledAssemblyName.IsNullOrWhiteSpace())
return AssemblyLoadingSuccessState.BadName;

if (syntaxTree is null)
return AssemblyLoadingSuccessState.InvalidAssembly;

if (!GetOrCreateACL(id, friendlyName, out var acl))
return AssemblyLoadingSuccessState.ACLLoadFailure;
Expand Down Expand Up @@ -419,8 +435,10 @@ public AssemblyLoadingSuccessState LoadAssembliesFromLocations([NotNull] IEnumer

if (filePaths is null)
{
throw new ArgumentNullException(
var exception = new ArgumentNullException(
$"{nameof(AssemblyManager)}::{nameof(LoadAssembliesFromLocations)}() | file paths supplied is null!");
this.OnException?.Invoke($"Error: {exception.Message}", exception);
throw exception;
}

ImmutableList<string> assemblyFilePaths = filePaths.ToImmutableList(); // copy the list before loading
Expand Down Expand Up @@ -468,12 +486,15 @@ public bool TryBeginDispose()
{
if (loadedAcl.Value.Acl is not null)
{
foreach (Delegate del in IsReadyToUnloadACL.GetInvocationList())
if (IsReadyToUnloadACL is not null)
{
if (del is System.Func<LoadedACL, bool> { } func)
foreach (Delegate del in IsReadyToUnloadACL.GetInvocationList())
{
if (!func.Invoke(loadedAcl.Value))
return false; // Not ready, exit
if (del is System.Func<LoadedACL, bool> { } func)
{
if (!func.Invoke(loadedAcl.Value))
return false; // Not ready, exit
}
}
}

Expand All @@ -492,9 +513,10 @@ public bool TryBeginDispose()
LoadedACLs.Clear();
return true;
}
catch
catch(Exception e)
{
// should never happen
this.OnException?.Invoke($"{nameof(TryBeginDispose)}() | Error: {e.Message}", e);
return false;
}
finally
Expand Down Expand Up @@ -609,9 +631,9 @@ private bool GetOrCreateACL(Guid id, string friendlyName, out LoadedACL acl)
}

}
catch
catch(Exception e)
{
// should never happen but in-case
this.OnException?.Invoke($"{nameof(GetOrCreateACL)}Error: {e.Message}", e);
acl = null;
return false;
}
Expand Down Expand Up @@ -648,9 +670,9 @@ private bool DisposeACL(Guid id)

return true;
}
catch
catch (Exception e)
{
// should never happen
this.OnException?.Invoke($"{nameof(DisposeACL)}() | Error: {e.Message}", e);
return false;
}
finally
Expand All @@ -677,8 +699,9 @@ private void RebuildTypesList()
.ToImmutableDictionary(t => t.FullName ?? t.Name, t => t);
_subTypesLookupCache.Clear();
}
catch(ArgumentException _)
catch(ArgumentException ae)
{
this.OnException?.Invoke($"{nameof(RebuildTypesList)}() | Error: {ae.Message}", ae);
try
{
// some types must've had duplicate type names, build the list while filtering
Expand All @@ -699,6 +722,7 @@ private void RebuildTypesList()
}
catch (Exception e)
{
this.OnException?.Invoke($"{nameof(RebuildTypesList)}() | Error: {e.Message}", e);
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to create list of default assembly types! Default AssemblyLoadContext types searching not available.");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Exception Details :{e.Message} | {e.InnerException}");
Expand Down Expand Up @@ -729,14 +753,14 @@ public sealed class LoadedACL
public readonly Guid Id;
private ImmutableDictionary<string, Type> _assembliesTypes = ImmutableDictionary<string, Type>.Empty;
public readonly MemoryFileAssemblyContextLoader Acl;
private readonly AssemblyManager _manager;

internal LoadedACL(Guid id, AssemblyManager manager, string friendlyName)
{
this.Id = id;
this.Acl = new(manager);
this._manager = manager;
this.Acl.FriendlyName = friendlyName;
this.Acl = new(manager)
{
FriendlyName = friendlyName
};
}
public ImmutableDictionary<string, Type> AssembliesTypes => _assembliesTypes;

Expand All @@ -752,7 +776,7 @@ internal void RebuildTypesList()
.SelectMany(a => a.GetSafeTypes())
.ToImmutableDictionary(t => t.FullName ?? t.Name, t => t);
}
catch(ArgumentException _)
catch(ArgumentException)
{
// some types must've had duplicate type names, build the list while filtering
Dictionary<string, Type> types = new();
Expand All @@ -774,7 +798,7 @@ internal void RebuildTypesList()

internal void ClearTypesList()
{
_assembliesTypes.Clear();
_assembliesTypes = ImmutableDictionary<string, Type>.Empty;
}
}

Expand Down Expand Up @@ -802,12 +826,12 @@ public static IEnumerable<Type> GetSafeTypes(this Assembly assembly)
{
return re.Types.Where(x => x != null)!;
}
catch (InvalidOperationException ioe)
catch (InvalidOperationException)
{
return new List<Type>();
}
}
catch (Exception e)
catch (Exception)
{
return new List<Type>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using MonoMod.Utils;
// ReSharper disable InconsistentNaming

namespace Barotrauma;

Expand Down Expand Up @@ -147,12 +148,12 @@ public bool LuaTryRegisterPackageTypes(string name, bool caseSensitive = false)
/// <summary>
/// Whether or not plugins' types have been instantiated.
/// </summary>
public bool PluginsInitialized { get; private set; } = false;
public bool PluginsInitialized { get; private set; }

/// <summary>
/// Whether or not plugins are fully loaded.
/// </summary>
public bool PluginsLoaded { get; private set; } = false;
public bool PluginsLoaded { get; private set; }

public IEnumerable<ContentPackage> GetCurrentPackagesByLoadOrder() => _currentPackagesByLoadOrder;

Expand Down Expand Up @@ -333,7 +334,7 @@ public AssemblyLoadingSuccessState LoadAssemblyPackages()
throw new DirectoryNotFoundException("No publicized assemblies found.");
}
// no directory found, use the other one
catch (DirectoryNotFoundException dne)
catch (DirectoryNotFoundException)
{
if (_luaCsSetup.Config.PreferToUseWorkshopLuaSetup)
{
Expand Down Expand Up @@ -454,8 +455,7 @@ public AssemblyLoadingSuccessState LoadAssemblyPackages()
if (reliableMap && OrderAndFilterPackagesByDependencies(
_packagesDependencies,
out var readyToLoad,
out var cannotLoadPackages,
null))
out var cannotLoadPackages))
{
packagesToLoadInOrder.AddRange(readyToLoad);
if (cannotLoadPackages is not null)
Expand Down Expand Up @@ -611,21 +611,19 @@ public AssemblyLoadingSuccessState LoadAssemblyPackages()

bool ShouldRunPackage(ContentPackage package, RunConfig config)
{
if (config.AutoGenerated)
return false;
return (!_luaCsSetup.Config.TreatForcedModsAsNormal && config.IsForced())
|| (ContentPackageManager.EnabledPackages.All.Contains(package) && config.IsForcedOrStandard());
}

void UpdatePackagesToDisable(ref HashSet<ContentPackage> list,
void UpdatePackagesToDisable(ref HashSet<ContentPackage> set,
ContentPackage newDisabledPackage,
IEnumerable<KeyValuePair<ContentPackage, ImmutableList<ContentPackage>>> dependenciesMap)
{
list.Add(newDisabledPackage);
set.Add(newDisabledPackage);
foreach (var package in dependenciesMap)
{
if (package.Value.Contains(newDisabledPackage))
list.Add(newDisabledPackage);
set.Add(newDisabledPackage);
}
}
}
Expand Down Expand Up @@ -655,7 +653,7 @@ public void RunPluginsInit()
// init
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.Initialize(), $"{nameof(IAssemblyPlugin.Initialize)}", plugin.GetType().Name);
TryRun(() => plugin.Initialize(), $"{nameof(IAssemblyPlugin.Initialize)}", $"CP: {_reverseLookupGuidList[contentPlugins.Key].Name} Plugin: {plugin.GetType().Name}");
}
}

Expand All @@ -664,7 +662,7 @@ public void RunPluginsInit()
// load complete
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.OnLoadCompleted(), $"{nameof(IAssemblyPlugin.OnLoadCompleted)}", plugin.GetType().Name);
TryRun(() => plugin.OnLoadCompleted(), $"{nameof(IAssemblyPlugin.OnLoadCompleted)}", $"CP: {_reverseLookupGuidList[contentPlugins.Key].Name} Plugin: {plugin.GetType().Name}");
}
}

Expand Down Expand Up @@ -698,7 +696,7 @@ public void RunPluginsPreInit()
// init
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.PreInitPatching(), $"{nameof(IAssemblyPlugin.PreInitPatching)}", plugin.GetType().Name);
TryRun(() => plugin.PreInitPatching(), $"{nameof(IAssemblyPlugin.PreInitPatching)}", $"CP: {_reverseLookupGuidList[contentPlugins.Key].Name} Plugin: {plugin.GetType().Name}");
}
}

Expand Down Expand Up @@ -741,21 +739,20 @@ public void InstantiatePlugins(bool force = false)
try
{
plugin = (IAssemblyPlugin)Activator.CreateInstance(type);
_loadedPlugins[pair.Key].Add(plugin);
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}. Now disposing...");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}");
#endif
TryRun(() => plugin?.Dispose(), "Dispose", type.FullName ?? type.Name);

plugin = null;
if (plugin is not null)
{
// ReSharper disable once AccessToModifiedClosure
TryRun(() => plugin?.Dispose(), nameof(IAssemblyPlugin.Dispose), type.FullName ?? type.Name);
plugin = null;
}
}
if (plugin is not null)
_loadedPlugins[pair.Key].Add(plugin);
else
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}");
}
}

Expand All @@ -772,7 +769,7 @@ public void UnloadPlugins()
{
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.Dispose(), $"{nameof(IAssemblyPlugin.Dispose)}", plugin.GetType().Name);
TryRun(() => plugin.Dispose(), $"{nameof(IAssemblyPlugin.Dispose)}", $"CP: {_reverseLookupGuidList[contentPlugins.Key].Name} Plugin: {plugin.GetType().Name}");
}
contentPlugins.Value.Clear();
}
Expand Down Expand Up @@ -816,9 +813,7 @@ private void TryRun(Action action, string messageMethodName, string messageTypeN
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while running {messageMethodName}() on plugin of type {messageTypeName}");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}");
#endif
}
}

Expand Down
Loading

0 comments on commit 91e2ab8

Please sign in to comment.