diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua index 9b5460b35c..d0497fd6c4 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua @@ -24,11 +24,8 @@ RegisterBarotrauma("Media.Video") RegisterBarotrauma("SoundsFile") RegisterBarotrauma("SoundPrefab") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.SoundPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.BackgroundMusic]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.GUISound]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.DamageSound]]") -RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SoundPrefab]]") +RegisterBarotrauma("PrefabCollection`1") +RegisterBarotrauma("PrefabSelector`1") RegisterBarotrauma("BackgroundMusic") RegisterBarotrauma("GUISound") RegisterBarotrauma("DamageSound") @@ -57,7 +54,6 @@ RegisterBarotrauma("Particles.Particle") RegisterBarotrauma("Particles.ParticleEmitterProperties") RegisterBarotrauma("Particles.ParticleEmitter") RegisterBarotrauma("Particles.ParticlePrefab") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.Particles.ParticlePrefab]]") RegisterBarotrauma("Lights.LightManager") RegisterBarotrauma("Lights.LightSource") @@ -146,4 +142,4 @@ RegisterBarotrauma("UISprite") RegisterBarotrauma("ParamsEditor") RegisterBarotrauma("Inventory+SlotReference") -RegisterBarotrauma("VisualSlot") \ No newline at end of file +RegisterBarotrauma("VisualSlot") diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua index f8c34e511c..0b643e0314 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua @@ -6,8 +6,8 @@ Register("System.Exception") Register("System.Console") Register("System.Exception") -RegisterBarotrauma("Success`2[[Barotrauma.ContentPackage],[System.Exception]]") -RegisterBarotrauma("Failure`2[[Barotrauma.ContentPackage],[System.Exception]]") +RegisterBarotrauma("Success`2") +RegisterBarotrauma("Failure`2") RegisterBarotrauma("LuaSByte") RegisterBarotrauma("LuaByte") @@ -24,8 +24,7 @@ RegisterBarotrauma("GameMain") RegisterBarotrauma("Networking.BanList") RegisterBarotrauma("Networking.BannedPlayer") -RegisterBarotrauma("Range`1[System.Single]") -RegisterBarotrauma("Range`1[System.Int32]") +RegisterBarotrauma("Range`1") RegisterBarotrauma("RichString") RegisterBarotrauma("Identifier") @@ -399,27 +398,11 @@ end RegisterBarotrauma("Camera") RegisterBarotrauma("Key") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.ItemPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.JobPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.CharacterPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.AfflictionPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentTree]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OrderPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OutpostGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.RuinGeneration.RuinGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LocationType]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventSet]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventManagerSettings]]") +RegisterBarotrauma("PrefabCollection`1") -RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SkillSettings]]") +RegisterBarotrauma("PrefabSelector`1") -RegisterBarotrauma("Pair`2[[Barotrauma.JobPrefab],[System.Int32]]") - -RegisterBarotrauma("Range`1[System.Single]") +RegisterBarotrauma("Pair`2") RegisterBarotrauma("Items.Components.Signal") RegisterBarotrauma("SubmarineInfo") @@ -461,4 +444,4 @@ LuaUserData.RemoveMember(workshopItem, "AddFavorite") LuaUserData.RemoveMember(workshopItem, "RemoveFavorite") LuaUserData.RemoveMember(workshopItem, "Vote") LuaUserData.RemoveMember(workshopItem, "GetUserVote") -LuaUserData.RemoveMember(workshopItem, "Edit") \ No newline at end of file +LuaUserData.RemoveMember(workshopItem, "Edit") diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs deleted file mode 100644 index 6ac4b8f9e6..0000000000 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.CodeAnalysis.Scripting; -using System.Reflection; -using Microsoft.CodeAnalysis.CSharp; -using System.Linq; -using Microsoft.CodeAnalysis; -using System.Runtime.Loader; -using System.Reflection.PortableExecutable; -using System.Reflection.Metadata; -using System.Text; -using System.Runtime.CompilerServices; - -namespace Barotrauma -{ - class CsScriptBase : AssemblyLoadContext - { - - public const string CsScriptAssembly = "NetScriptAssembly"; - - public static readonly string[] LoadedAssemblyName = { - CsScriptBase.CsScriptAssembly - }; - - public static Dictionary Revision = new Dictionary() - { - { CsScriptAssembly, 0} - }; - - public CSharpParseOptions ParseOptions { get; protected set; } - - public CsScriptBase() : base(isCollectible: true) { - ParseOptions = CSharpParseOptions.Default - .WithPreprocessorSymbols(new[] { LuaCsSetup.IsServer ? "SERVER" : (LuaCsSetup.IsClient ? "CLIENT" : "UNDEFINED") }); - } - - public static SyntaxTree AssemblyInfoSyntaxTree(string asmName = null) - { - Revision[asmName] = (int)Revision[asmName] + 1; - var asmInfo = new StringBuilder(); - asmInfo.AppendLine("using System.Reflection;"); - asmInfo.AppendLine($"[assembly: AssemblyMetadata(\"Revision\", \"{Revision[asmName]}\")]"); - asmInfo.AppendLine($"[assembly: AssemblyVersion(\"0.0.0.{Revision[asmName]}\")]"); - return CSharpSyntaxTree.ParseText(asmInfo.ToString(), CSharpParseOptions.Default); - } - - ~CsScriptBase() { } - } -} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs deleted file mode 100644 index 90d858f822..0000000000 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ /dev/null @@ -1,286 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.CodeAnalysis.Scripting; -using System.Reflection; -using Microsoft.CodeAnalysis.CSharp; -using System.Linq; -using Microsoft.CodeAnalysis; -using System.Runtime.Loader; -using System.Reflection.PortableExecutable; -using System.Reflection.Metadata; -using System.Text.RegularExpressions; -using System.Xml.Linq; - -namespace Barotrauma -{ - class CsScriptLoader : CsScriptBase - { - private List defaultReferences; - - private Dictionary> sources; - public Assembly Assembly { get; private set; } - - public CsScriptLoader() - { - defaultReferences = AppDomain.CurrentDomain.GetAssemblies() - .Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))) - .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) - .ToList(); - - sources = new Dictionary>(); - Assembly = null; - } - - private enum RunType { Standard, Forced, None }; - private bool ShouldRun(ContentPackage cp, string path) - { - if (!Directory.Exists(path + "CSharp")) - { - return false; - } - - var isEnabled = ContentPackageManager.EnabledPackages.All.Contains(cp); - if (File.Exists(path + "CSharp/RunConfig.xml")) - { - Stream stream = File.Open(path + "CSharp/RunConfig.xml", FileMode.Open, FileAccess.Read, FileShare.ReadWrite); - var doc = XDocument.Load(stream); - var elems = doc.Root.Elements().ToArray(); - var elem = elems.FirstOrDefault(e => e.Name.LocalName.Equals(LuaCsSetup.IsServer ? "Server" : (LuaCsSetup.IsClient ? "Client" : "None"), StringComparison.OrdinalIgnoreCase)); - - if (elem != null && Enum.TryParse(elem.Value, true, out RunType rtValue)) - { - if (rtValue == RunType.Standard && isEnabled) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Standard)"); - return true; - } - else if (rtValue == RunType.Forced && (isEnabled || !GameMain.LuaCs.Config.TreatForcedModsAsNormal)) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Forced)"); - return true; - } - else if (rtValue == RunType.None) - { - return false; - } - } - - stream.Close(); - } - - if (isEnabled) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Assumed)"); - return true; - } - else - { - return false; - } - } - - public void SearchFolders() - { - var packagesAdded = new HashSet(); - var paths = new Dictionary(); - foreach (var cp in ContentPackageManager.AllPackages.Concat(ContentPackageManager.EnabledPackages.All)) - { - if (packagesAdded.Contains(cp)) { continue; } - var path = $"{Path.GetFullPath(Path.GetDirectoryName(cp.Path)).Replace('\\', '/')}/"; - if (ShouldRun(cp, path)) - { - if (paths.ContainsKey(cp.Name)) - { - if (ContentPackageManager.EnabledPackages.All.Contains(cp)) - { - paths[cp.Name] = path; - } - } - else - { - paths.Add(cp.Name, path); - } - packagesAdded.Add(cp); - } - } - - foreach ((var _, var path) in paths) - { - RunFolder(path); - } - } - - public bool HasSources { get => sources.Count > 0; } - - private void AddSources(string folder) - { - foreach (var str in DirSearch(folder)) - { - string s = str.Replace("\\", "/"); - - if (sources.ContainsKey(folder)) - { - sources[folder].Add(s); - } - else - { - sources.Add(folder, new List { s }); - } - } - } - - private void RunFolder(string folder) - { - - AddSources(folder + "/CSharp/Shared"); - -#if SERVER - AddSources(folder + "/CSharp/Server"); -#else - AddSources(folder + "/CSharp/Client"); -#endif - } - - private IEnumerable ParseSources() { - var syntaxTrees = new List(); - - if (sources.Count <= 0) throw new Exception("No Cs sources detected"); - syntaxTrees.Add(AssemblyInfoSyntaxTree(CsScriptAssembly)); - foreach ((var folder, var src) in sources) - { - try - { - foreach (var file in src) - { - var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), ParseOptions, file); - - syntaxTrees.Add(tree); - } - } - catch (Exception ex) - { - LuaCsLogger.LogError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace, LuaCsMessageOrigin.CSharpMod); - } - } - - return syntaxTrees; - } - - private ContentPackage FindSourcePackage(Diagnostic diagnostic) - { - if (diagnostic.Location.SourceTree == null) - { - return null; - } - - string path = diagnostic.Location.SourceTree.FilePath; - foreach (var package in ContentPackageManager.AllPackages) - { - if (Path.GetFullPath(path).StartsWith(Path.GetFullPath(package.Dir))) - { - return package; - } - } - - return null; - } - - public List Compile() - { - IEnumerable syntaxTrees = ParseSources(); - - var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) - .WithMetadataImportOptions(MetadataImportOptions.All) - .WithOptimizationLevel(OptimizationLevel.Release) - .WithAllowUnsafe(true); - - var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic); - topLevelBinderFlagsProperty.SetValue(options, (uint)1 << 22); - - var compilation = CSharpCompilation.Create(CsScriptAssembly, syntaxTrees, defaultReferences, options); - - using (var mem = new MemoryStream()) - { - var result = compilation.Emit(mem); - if (!result.Success) - { - IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); - - string errStr = "CS MODS NOT LOADED | Compilation errors:"; - foreach (Diagnostic diagnostic in failures) - { - errStr += $"\n{diagnostic}"; -#if CLIENT - ContentPackage package = FindSourcePackage(diagnostic); - if (package != null) - { - LuaCsLogger.ShowErrorOverlay($"{package.Name} {package.ModVersion} is causing compilation errors. Check debug console for more details.", 7f, 7f); - } -#endif - } - LuaCsLogger.LogError(errStr, LuaCsMessageOrigin.CSharpMod); - } - else - { - mem.Seek(0, SeekOrigin.Begin); - Assembly = LoadFromStream(mem); - } - } - - if (Assembly != null) - { - RegisterAssemblyWithNativeGame(Assembly); - try - { - return Assembly.GetTypes().Where(t => t.IsSubclassOf(typeof(ACsMod))).ToList(); - } - catch (ReflectionTypeLoadException re) - { - LuaCsLogger.LogError($"Unable to load CsMod Types. {re.Message}", LuaCsMessageOrigin.CSharpMod); - throw re; - } - } - else - { - throw new Exception("Unable to create cs mods assembly."); - } - } - - /// - /// This function should be used whenever a new assembly is created. Wrapper to allow more complicated setup later if need be. - /// - private static void RegisterAssemblyWithNativeGame(Assembly assembly) - { - Barotrauma.ReflectionUtils.AddNonAbstractAssemblyTypes(assembly); - } - - /// - /// This function should be used whenever a new assembly is about to be destroyed/unloaded. Wrapper to allow more complicated setup later if need be. - /// - /// Assembly to remove - private static void UnregisterAssemblyFromNativeGame(Assembly assembly) - { - Barotrauma.ReflectionUtils.RemoveAssemblyFromCache(assembly); - } - - private static string[] DirSearch(string sDir) - { - if (!Directory.Exists(sDir)) - { - return new string[] {}; - } - - return Directory.GetFiles(sDir, "*.cs", SearchOption.AllDirectories); - } - - public void Clear() - { - if (Assembly != null) - { - UnregisterAssemblyFromNativeGame(Assembly); - Assembly = null; - } - } - } -} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaProxy.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaProxy.cs deleted file mode 100644 index 2f2ca12e68..0000000000 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaProxy.cs +++ /dev/null @@ -1,8 +0,0 @@ -using System; -using MoonSharp.Interpreter; -using Barotrauma.Networking; - -namespace Barotrauma -{ - -} \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 6c89d46802..10988b54de 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -9,8 +9,8 @@ using System.Diagnostics; using MoonSharp.VsCodeDebugger; using System.Reflection; +using System.Runtime.Loader; -[assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsScriptAssembly, AllInternalsVisible = true)] namespace Barotrauma { class LuaCsSetupConfig @@ -66,12 +66,18 @@ partial class LuaCsSetup public LuaCsSteam Steam { get; private set; } public LuaCsPerformanceCounter PerformanceCounter { get; private set; } + // must be available at anytime + private static AssemblyManager _assemblyManager; + public static AssemblyManager AssemblyManager => _assemblyManager ??= new AssemblyManager(); + + private CsPackageManager _pluginPackageManager; + public CsPackageManager PluginPackageManager => _pluginPackageManager ??= new CsPackageManager(AssemblyManager, this); + public LuaCsModStore ModStore { get; private set; } private LuaRequire require { get; set; } - - public CsScriptLoader CsScriptLoader { get; private set; } public LuaCsSetupConfig Config { get; private set; } public MoonSharpVsCodeDebugServer DebugServer { get; private set; } + public bool IsInitialized { get; private set; } private bool ShouldRunCs { @@ -90,7 +96,6 @@ public LuaCsSetup() Game = new LuaGame(); Networking = new LuaCsNetworking(); - DebugServer = new MoonSharpVsCodeDebugServer(); if (File.Exists(configFileName)) @@ -105,35 +110,11 @@ public LuaCsSetup() Config = new LuaCsSetupConfig(); } } - + + [Obsolete("Use AssemblyManager::GetTypesByName()")] public static Type GetType(string typeName, bool throwOnError = false, bool ignoreCase = false) { - if (typeName == null || typeName.Length == 0) { return null; } - - var byRef = false; - if (typeName.StartsWith("out ") || typeName.StartsWith("ref ")) - { - typeName = typeName.Remove(0, 4); - byRef = true; - } - - var type = Type.GetType(typeName, throwOnError, ignoreCase); - if (type != null) { return byRef ? type.MakeByRefType() : type; } - foreach (var a in AppDomain.CurrentDomain.GetAssemblies()) - { - if (CsScriptBase.LoadedAssemblyName.Contains(a.GetName().Name)) - { - var attrs = a.GetCustomAttributes(); - var revision = attrs.FirstOrDefault(attr => attr.Key == "Revision")?.Value; - if (revision != null && int.Parse(revision) != (int)CsScriptBase.Revision[a.GetName().Name]) { continue; } - } - type = a.GetType(typeName, throwOnError, ignoreCase); - if (type != null) - { - return byRef ? type.MakeByRefType() : type; - } - } - return null; + return AssemblyManager.GetTypesByName(typeName).FirstOrDefault((Type)null); } public void ToggleDebugger(int port = 41912) @@ -293,17 +274,21 @@ public void Update() public void Stop() { - foreach (var type in AppDomain.CurrentDomain.GetAssemblies().Where(a => a.GetName().Name == CsScriptBase.CsScriptAssembly).SelectMany(assembly => assembly.GetTypes())) + + // unregister types + foreach (Type type in AssemblyManager.GetAllLoadedACLs().SelectMany( + acl => acl.AssembliesTypes.Select(kvp => kvp.Value))) { UserData.UnregisterType(type, true); } + + PluginPackageManager.UnloadPlugins(); // stop plugin code execution - foreach (var mod in ACsMod.LoadedMods.ToArray()) + if (Lua?.Globals is not null) { - mod.Dispose(); + Lua.Globals.Remove("CsPackageManager"); + Lua.Globals.Remove("AssemblyManager"); } - - ACsMod.LoadedMods.Clear(); if (Thread.CurrentThread == GameMain.MainThread) { @@ -317,27 +302,27 @@ public void Stop() Game?.Stop(); - Hook.Clear(); + Hook?.Clear(); ModStore.Clear(); + LuaScriptLoader = null; + Lua = null; + + // we can only unload assemblies after clearing ModStore/references. + PluginPackageManager.Dispose(); + Game = new LuaGame(); Networking = new LuaCsNetworking(); Timer = new LuaCsTimer(); Steam = new LuaCsSteam(); PerformanceCounter = new LuaCsPerformanceCounter(); - LuaScriptLoader = null; - Lua = null; - if (CsScriptLoader != null) - { - CsScriptLoader.Clear(); - CsScriptLoader.Unload(); - CsScriptLoader = null; - } + IsInitialized = false; } public void Initialize(bool forceEnableCs = false) { - Stop(); + if (IsInitialized) + Stop(); LuaCsLogger.LogMessage("Lua! Version " + AssemblyInfo.GitRevision); @@ -380,6 +365,9 @@ public void Initialize(bool forceEnableCs = false) UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterExtensionType(typeof(MathUtils)); UserData.RegisterExtensionType(typeof(XMLExtensions)); @@ -430,65 +418,80 @@ public void Initialize(bool forceEnableCs = false) DebugConsole.AddWarning("Cs package active! Cs mods are NOT sandboxed, use it at your own risk!"); } - CsScriptLoader = new CsScriptLoader(); - CsScriptLoader.SearchFolders(); - if (CsScriptLoader.HasSources) + Lua.Globals["PluginPackageManager"] = PluginPackageManager; + Lua.Globals["AssemblyManager"] = AssemblyManager; + + try { - try + Stopwatch taskTimer = new(); + taskTimer.Start(); + ModStore.Clear(); + + var state = PluginPackageManager.LoadAssemblyPackages(); + if (state is AssemblyLoadingSuccessState.Success or AssemblyLoadingSuccessState.AlreadyLoaded) { - Stopwatch compilationTime = new Stopwatch(); - compilationTime.Start(); - var modTypes = CsScriptLoader.Compile(); - - modTypes.ForEach(t => - { - t.GetConstructor(new Type[] { })?.Invoke(null); - }); - - compilationTime.Stop(); - LuaCsLogger.LogMessage($"Took {compilationTime.ElapsedMilliseconds}ms to compile and run Cs Scripts."); + if(!PluginPackageManager.PluginsInitialized) + PluginPackageManager.InstantiatePlugins(true); + if(!PluginPackageManager.PluginsPreInit) + PluginPackageManager.RunPluginsPreInit(); // this is intended to be called at startup in the future + if(!PluginPackageManager.PluginsLoaded) + PluginPackageManager.RunPluginsInit(); + state = AssemblyLoadingSuccessState.Success; + taskTimer.Stop(); + ModUtils.Logging.PrintMessage($"{nameof(LuaCsSetup)}: Completed assembly loading. Total time {taskTimer.ElapsedMilliseconds}ms."); + } + else + { + PluginPackageManager.Dispose(); // cleanup if there's an error } - catch (Exception ex) + + if(state is not AssemblyLoadingSuccessState.Success) { - LuaCsLogger.HandleException(ex, LuaCsMessageOrigin.CSharpMod); + ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}: Error while loading Cs-Assembly Mods | Err: {state}"); + taskTimer.Stop(); } } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}::{nameof(Initialize)}() | Error while loading assemblies! Details: {e.Message} | {e.StackTrace}"); + } + IsInitialized = true; } ContentPackage luaPackage = GetPackage(LuaForBarotraumaId); - void runLocal() + void RunLocal() { LuaCsLogger.LogMessage("Using LuaSetup.lua from the Barotrauma Lua/ folder."); string luaPath = LuaSetupFile; CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath))); } - void runWorkshop() + void RunWorkshop() { LuaCsLogger.LogMessage("Using LuaSetup.lua from the content package."); string luaPath = Path.Combine(Path.GetDirectoryName(luaPackage.Path), "Binary/Lua/LuaSetup.lua"); CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath))); } - void runNone() + void RunNone() { LuaCsLogger.LogError("LuaSetup.lua not found! Lua/LuaSetup.lua, no Lua scripts will be executed or work.", LuaCsMessageOrigin.LuaMod); } if (Config.PreferToUseWorkshopLuaSetup) { - if (luaPackage != null) { runWorkshop(); } - else if (File.Exists(LuaSetupFile)) { runLocal(); } - else { runNone(); } + if (luaPackage != null) { RunWorkshop(); } + else if (File.Exists(LuaSetupFile)) { RunLocal(); } + else { RunNone(); } } else { - if (File.Exists(LuaSetupFile)) { runLocal(); } - else if (luaPackage != null) { runWorkshop(); } - else { runNone(); } + if (File.Exists(LuaSetupFile)) { RunLocal(); } + else if (luaPackage != null) { RunWorkshop(); } + else { RunNone(); } } executionNumber++; diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs index 8501004cd2..14a7685ef4 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.IO; using System.Linq; @@ -259,14 +260,22 @@ private enum ValueType private static Type[] LoadDocTypes(XElement typesElem) { var result = new List(); + var loadedTypes = LuaCsSetup.AssemblyManager + .GetAllTypesInLoadedAssemblies() + .ToImmutableHashSet(); + foreach (var elem in typesElem.Elements()) { - var type = Type.GetType(elem.Value); - if (type == null && GameMain.LuaCs?.CsScriptLoader?.Assembly != null) type = GameMain.LuaCs.CsScriptLoader.Assembly.GetType(elem.Value); - if (type == null) throw new Exception($"Type {elem.Value} not found."); - result.Add(type); - + var typesFound = loadedTypes.Where(t => t.FullName?.EndsWith(elem.Value) ?? false).ToImmutableList(); + if (!typesFound.Any()) + { + ModUtils.Logging.PrintError( + $"{nameof(LuaCsConfig)}::{nameof(LoadDocTypes)}() | Unable to find a matching type for {elem.Value}"); + continue; + } + result.AddRange(typesFound); } + return result.ToArray(); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs new file mode 100644 index 0000000000..962f93371a --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs @@ -0,0 +1,331 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Xml.Serialization; +using Barotrauma; +using Barotrauma.Items.Components; +using Barotrauma.Networking; +using Microsoft.CodeAnalysis; + +namespace Barotrauma; + +public static class ModUtils +{ + #region LOGGING + + public static class Logging + { + public static void PrintMessage(string s) + { +#if SERVER + LuaCsLogger.LogMessage($"[Server] {s}"); +#else + LuaCsLogger.LogMessage($"[Client] {s}"); +#endif + } + + public static void PrintError(string s) + { +#if SERVER + LuaCsLogger.LogError($"[Server] {s}"); +#else + LuaCsLogger.LogError($"[Client] {s}"); +#endif + } + } + + #endregion + + #region FILE_IO + + // ReSharper disable once InconsistentNaming + public static class IO + { + public static IEnumerable FindAllFilesInDirectory(string folder, string pattern, + SearchOption option) + { + try + { + return Directory.GetFiles(folder, pattern, option); + } + catch (DirectoryNotFoundException e) + { + return new string[] { }; + } + } + + public static string PrepareFilePathString(string filePath) => + PrepareFilePathString(Path.GetDirectoryName(filePath)!, Path.GetFileName(filePath)); + + public static string PrepareFilePathString(string path, string fileName) => + Path.Combine(SanitizePath(path), SanitizeFileName(fileName)); + + public static string SanitizeFileName(string fileName) + { + foreach (char c in Barotrauma.IO.Path.GetInvalidFileNameCharsCrossPlatform()) + fileName = fileName.Replace(c, '_'); + return fileName; + } + + /// + /// Gets the sanitized path for the top-level directory for a given content package. + /// + /// + /// + public static string GetContentPackageDir(ContentPackage package) + { + return SanitizePath(Path.GetFullPath(package.Dir)); + } + + public static string SanitizePath(string path) + { + foreach (char c in Path.GetInvalidPathChars()) + path = path.Replace(c.ToString(), "_"); + return path.CleanUpPath(); + } + + public static IOActionResultState GetOrCreateFileText(string filePath, out string fileText, Func fileDataFactory = null, bool createFile = true) + { + fileText = null; + string fp = Path.GetFullPath(SanitizePath(filePath)); + + IOActionResultState ioActionResultState = IOActionResultState.Success; + if (createFile) + { + ioActionResultState = CreateFilePath(SanitizePath(filePath), out fp, fileDataFactory); + } + else if (!File.Exists(fp)) + { + return IOActionResultState.FileNotFound; + } + + if (ioActionResultState == IOActionResultState.Success) + { + try + { + fileText = File.ReadAllText(fp!); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + return ioActionResultState; + } + + public static IOActionResultState CreateFilePath(string filePath, out string formattedFilePath, Func fileDataFactory = null) + { + string file = Path.GetFileName(filePath); + string path = Path.GetDirectoryName(filePath)!; + + formattedFilePath = IO.PrepareFilePathString(path, file); + try + { + if (!Directory.Exists(path)) + Directory.CreateDirectory(path); + if (!File.Exists(formattedFilePath)) + File.WriteAllText(formattedFilePath, fileDataFactory is null ? "" : fileDataFactory.Invoke()); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {formattedFilePath ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {formattedFilePath ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {path ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {formattedFilePath ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {formattedFilePath ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {formattedFilePath ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {path ?? "null"} | Exception Details: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + public static IOActionResultState WriteFileText(string filePath, string fileText) + { + IOActionResultState ioActionResultState = CreateFilePath(filePath, out var fp); + if (ioActionResultState == IOActionResultState.Success) + { + try + { + File.WriteAllText(fp!, fileText); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + return ioActionResultState; + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static bool LoadOrCreateTypeXml(out T instance, + string filepath, Func typeFactory = null, bool createFile = true) where T : class, new() + { + instance = null; + filepath = filepath.CleanUpPath(); + if (IOActionResultState.Success == GetOrCreateFileText( + filepath, out string fileText, typeFactory is not null ? () => + { + using StringWriter sw = new StringWriter(); + T t = typeFactory?.Invoke(); + if (t is not null) + { + XmlSerializer s = new XmlSerializer(typeof(T)); + s.Serialize(sw, t); + return sw.ToString(); + } + return ""; + } : null, createFile)) + { + XmlSerializer s = new XmlSerializer(typeof(T)); + try + { + using TextReader tr = new StringReader(fileText); + instance = (T)s.Deserialize(tr); + return true; + } + catch(InvalidOperationException ioe) + { + ModUtils.Logging.PrintError($"Error while parsing type data for {typeof(T)}."); + #if DEBUG + ModUtils.Logging.PrintError($"Exception: {ioe.Message}. Details: {ioe.InnerException?.Message}"); + #endif + instance = null; + return false; + } + } + + return false; + } + + public enum IOActionResultState + { + Success, FileNotFound, FilePathNull, FilePathInvalid, DirectoryMissing, PathTooLong, InvalidOperation, IOFailure, UnknownError + } + } + + #endregion + + #region GAME + + public static class Game + { + /// + /// Returns whether or not there is a round running. + /// + /// + public static bool IsRoundInProgress() + { +#if CLIENT + if (Screen.Selected is not null + && Screen.Selected.IsEditor) + return false; +#endif + return GameMain.GameSession is not null && Level.Loaded is not null; + } + + } + + #endregion +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs similarity index 60% rename from Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs rename to Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs index 60ef6551d5..76dfac73f2 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs @@ -1,11 +1,11 @@ using System; using System.Collections.Generic; using System.IO; -using System.Reflection; namespace Barotrauma { - public abstract class ACsMod : IDisposable + [Obsolete("Make your class implement IAssemblyPlugin instead.")] + public abstract class ACsMod : IAssemblyPlugin { private static List mods = new List(); public static List LoadedMods { get => mods; } @@ -18,7 +18,6 @@ public static string GetStoreFolder() where T : ACsMod if (!Directory.Exists(modFolder)) Directory.CreateDirectory(modFolder); return modFolder; } - public static string GetSoreFolder() where T : ACsMod => GetStoreFolder(); public bool IsDisposed { get; private set; } @@ -29,7 +28,23 @@ public ACsMod() LoadedMods.Add(this); } - public void Dispose() + /// + /// Called as soon as plugin loading begins, use this for internal setup only. + /// + public virtual void Initialize() { } + + /// + /// Called once all plugins have completed Initialization. Put cross-mod code here. + /// + public virtual void OnLoadCompleted() { } + + /// + /// [NotImplemented] Called before vanilla content is loaded. Use to patch Barotrauma classes before they're + /// instantiated. + /// + public void PreInitPatching() { } + + public virtual void Dispose() { try { @@ -43,8 +58,7 @@ public void Dispose() LoadedMods.Remove(this); IsDisposed = true; } - - /// Error or client exit + public abstract void Stop(); } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs new file mode 100644 index 0000000000..6e60184bb7 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs @@ -0,0 +1,6 @@ +namespace Barotrauma; + +public enum ApplicationMode +{ + Client, Server +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs new file mode 100644 index 0000000000..e55821eb3d --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs @@ -0,0 +1,15 @@ +namespace Barotrauma; + +public enum AssemblyLoadingSuccessState +{ + ACLLoadFailure, + AlreadyLoaded, + BadFilePath, + CannotLoadFile, + InvalidAssembly, + NoAssemblyFound, + PluginInstanceFailure, + BadName, + CannotLoadFromStream, + Success +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs new file mode 100644 index 0000000000..64f12bf684 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs @@ -0,0 +1,772 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.Loader; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +// ReSharper disable EventNeverSubscribedTo.Global +// ReSharper disable InconsistentNaming + +namespace Barotrauma; + +/*** + * Note: This class was written to be thread-safe in order to allow parallelization in loading in the future if the need + * becomes necessary as there is almost no serial performance overhead for adding threading protection. + */ + +/// +/// Provides functionality for the loading, unloading and management of plugins implementing IAssemblyPlugin. +/// All plugins are loaded into their own AssemblyLoadContext along with their dependencies. +/// +public partial class AssemblyManager +{ + #region ExternalAPI + + /// + /// Called when an assembly is loaded. + /// + public event Action OnAssemblyLoaded; + + /// + /// Called when an assembly is marked for unloading, before unloading begins. You should use this to cleanup + /// any references that you have to this assembly. + /// + public event Action OnAssemblyUnloading; + + /// + /// Called whenever an exception is thrown. First arg is a formatted message, Second arg is the Exception. + /// + public event Action OnException; + + /// + /// For unloading issue debugging. Called whenever MemoryFileAssemblyContextLoader [load context] is unloaded. + /// + public event Action OnACLUnload; + + #if DEBUG + + /// + /// [DEBUG ONLY] + /// Returns a list of the current unloading ACLs. + /// + public ImmutableList> StillUnloadingACLs + { + get + { + OpsLockUnloaded.EnterReadLock(); + try + { + return UnloadingACLs.ToImmutableList(); + } + finally + { + OpsLockUnloaded.ExitReadLock(); + } + } + } + + #endif + + + // ReSharper disable once MemberCanBePrivate.Global + /// + /// Checks if there are any AssemblyLoadContexts still in the process of unloading. + /// + public bool IsCurrentlyUnloading + { + get + { + OpsLockUnloaded.EnterReadLock(); + try + { + return UnloadingACLs.Any(); + } + catch (Exception) + { + return false; + } + finally + { + OpsLockUnloaded.ExitReadLock(); + } + } + } + + // Old API compatibility + public IEnumerable GetSubTypesInLoadedAssemblies() + { + return GetSubTypesInLoadedAssemblies(false); + } + + + /// + /// Allows iteration over all non-interface types in all loaded assemblies in the AsmMgr that are assignable to the given type (IsAssignableFrom). + /// Warning: care should be used when using this method in hot paths as performance may be affected. + /// + /// The type to compare against + /// Forces caches to clear and for the lists of types to be rebuilt. + /// An Enumerator for matching types. + public IEnumerable GetSubTypesInLoadedAssemblies(bool rebuildList) + { + Type targetType = typeof(T); + string typeName = targetType.FullName ?? targetType.Name; + + // rebuild + if (rebuildList) + RebuildTypesList(); + + // check cache + if (_subTypesLookupCache.TryGetValue(typeName, out var subTypeList)) + { + return subTypeList; + } + + // build from scratch + OpsLockLoaded.EnterReadLock(); + try + { + // build list + var list1 = _defaultContextTypes + .Where(kvp1 => targetType.IsAssignableFrom(kvp1.Value) && !kvp1.Value.IsInterface) + .Concat(LoadedACLs + .SelectMany(kvp => kvp.Value.AssembliesTypes) + .Where(kvp2 => targetType.IsAssignableFrom(kvp2.Value) && !kvp2.Value.IsInterface)) + .Select(kvp3 => kvp3.Value) + .ToImmutableList(); + + // only add if we find something + if (list1.Count > 0) + { + if (!_subTypesLookupCache.TryAdd(typeName, list1)) + { + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to add subtypes to cache of type {typeName}!"); + } + } + else + { + ModUtils.Logging.PrintMessage($"{nameof(AssemblyManager)}: Warning: No types found during search for subtypes of {typeName}"); + } + + return list1; + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + /// + /// Tries to get types assignable to type from the ACL given the Guid. + /// + /// + /// + /// + /// + public bool TryGetSubTypesFromACL(Guid id, out IEnumerable types) + { + Type targetType = typeof(T); + + if (TryGetACL(id, out var acl)) + { + types = acl.AssembliesTypes + .Where(kvp => targetType.IsAssignableFrom(kvp.Value) && !kvp.Value.IsInterface) + .Select(kvp => kvp.Value); + return true; + } + + types = null; + return false; + } + + /// + /// Tries to get types from the ACL given the Guid. + /// + /// + /// + /// + /// + public bool TryGetSubTypesFromACL(Guid id, out IEnumerable types) + { + if (TryGetACL(id, out var acl)) + { + types = acl.AssembliesTypes.Select(kvp => kvp.Value); + return true; + } + + types = null; + return false; + } + + + /// + /// Allows iteration over all types, including interfaces, in all loaded assemblies in the AsmMgr who's names match the string. + /// Note: Will return the by-reference equivalent type if the type name is prefixed with "out " or "ref ". + /// + /// The string name of the type to search for. + /// An Enumerator for matching types. + public IEnumerable GetTypesByName(string typeName) + { + bool byRef = false; + if (typeName.StartsWith("out ") || typeName.StartsWith("ref ")) + { + typeName = typeName.Remove(0, 4); + byRef = true; + } + + List types = new(); + + TypesListHelper(); + if (types.Count > 0) + return types; + + // we couldn't find it, rebuild and try one more time + RebuildTypesList(); + TypesListHelper(); + return types; + + void TypesListHelper() + { + if (_defaultContextTypes.TryGetValue(typeName, out var type1)) + { + if (type1 is not null) + types.Add(byRef ? type1.MakeByRefType() : type1); + } + + OpsLockLoaded.EnterReadLock(); + try + { + foreach (KeyValuePair loadedAcl in LoadedACLs) + { + var at = loadedAcl.Value.AssembliesTypes; + if (at.TryGetValue(typeName, out var type2)) + { + if (type2 is not null) + types.Add(byRef ? type2.MakeByRefType() : type2); + } + } + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + } + + /// + /// Allows iteration over all types (including interfaces) in all loaded assemblies managed by the AsmMgr. + /// Warning: High usage may result in performance issues. + /// + /// An Enumerator for iteration. + public IEnumerable GetAllTypesInLoadedAssemblies() + { + OpsLockLoaded.EnterReadLock(); + try + { + return AssemblyLoadContext.Default.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .Concat(LoadedACLs + .SelectMany(kvp => kvp.Value.AssembliesTypes.Select(kv => kv.Value))) + .ToImmutableList(); + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + /// + /// Returns a list of all loaded ACLs. + /// WARNING: References to these ACLs outside of the AssemblyManager should be kept in a WeakReference in order + /// to avoid causing issues with unloading/disposal. + /// + /// + public IEnumerable GetAllLoadedACLs() + { + try + { + OpsLockLoaded.EnterReadLock(); + return LoadedACLs.Select(kvp => kvp.Value).ToImmutableList(); + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + + } + + #endregion + + #region InternalAPI + + /// + /// Used by content package and plugin management to stop unloading of a given ACL until all plugins have gracefully closed. + /// + public event System.Func IsReadyToUnloadACL; + + public AssemblyLoadingSuccessState LoadAssemblyFromMemory([NotNull] string compiledAssemblyName, + [NotNull] IEnumerable syntaxTree, + IEnumerable externalMetadataReferences, + [NotNull] CSharpCompilationOptions compilationOptions, + ref Guid id, + IEnumerable externFileAssemblyRefs = null) + { + // validation + if (compiledAssemblyName.IsNullOrWhiteSpace()) + return AssemblyLoadingSuccessState.BadName; + + if (!GetOrCreateACL(id, out var acl)) + return AssemblyLoadingSuccessState.ACLLoadFailure; + + id = acl.Id; // pass on true id returned + + // this acl is already hosting an in-memory assembly + if (acl.Acl.CompiledAssembly is not null) + return AssemblyLoadingSuccessState.AlreadyLoaded; + + // compile + var state = acl.Acl.CompileAndLoadScriptAssembly(compiledAssemblyName, syntaxTree, externalMetadataReferences, + compilationOptions, out var messages, externFileAssemblyRefs); + + // get types + if (state is AssemblyLoadingSuccessState.Success) + { + _subTypesLookupCache.Clear(); + acl.RebuildTypesList(); + OnAssemblyLoaded?.Invoke(acl.Acl.CompiledAssembly); + } + else + { + ModUtils.Logging.PrintError($"Unable to compile assembly '{compiledAssemblyName}' due to errors: {messages}"); + } + + return state; + } + + /// + /// Switches the ACL with the given Guid to Template Mode, which disables assembly name resolution for any assemblies loaded in it. + /// These ACLs are intended to be used to host Assemblies for information only and not for code execution. + /// WARNING: This process is irreversible. + /// + /// Guid of the ACL. + /// Whether or not an ACL was found with the given ID. + public bool SetACLToTemplateMode(Guid guid) + { + if (!TryGetACL(guid, out var acl)) + return false; + acl.Acl.IsTemplateMode = true; + return true; + } + + /// + /// Tries to load all assemblies at the supplied file paths list into the ACl with the given Guid. + /// If the supplied Guid is Empty, then a new ACl will be created and the Guid will be assigned to it. + /// + /// List of assemblies to try and load. + /// Guid of the ACL or Empty if none specified. Guid of ACL will be assigned to this var. + /// Operation success messages. + /// + public AssemblyLoadingSuccessState LoadAssembliesFromLocations([NotNull] IEnumerable filePaths, + ref Guid id) + { + + if (filePaths is null) + { + throw new ArgumentNullException( + $"{nameof(AssemblyManager)}::{nameof(LoadAssembliesFromLocations)}() | file paths supplied is null!"); + } + + ImmutableList assemblyFilePaths = filePaths.ToImmutableList(); // copy the list before loading + + if (!assemblyFilePaths.Any()) + { + return AssemblyLoadingSuccessState.NoAssemblyFound; + } + + if (GetOrCreateACL(id, out var loadedAcl)) + { + var state = loadedAcl.Acl.LoadFromFiles(assemblyFilePaths); + // if failure, we dispose of the acl + if (state != AssemblyLoadingSuccessState.Success) + { + DisposeACL(loadedAcl.Id); + ModUtils.Logging.PrintError($"ACL failed, unloading..."); + return state; + } + // build types list + _subTypesLookupCache.Clear(); + loadedAcl.RebuildTypesList(); + id = loadedAcl.Id; + foreach (Assembly assembly in loadedAcl.Acl.Assemblies) + { + OnAssemblyLoaded?.Invoke(assembly); + } + return state; + } + + return AssemblyLoadingSuccessState.ACLLoadFailure; + } + + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Synchronized)] + public bool TryBeginDispose() + { + OpsLockLoaded.EnterWriteLock(); + OpsLockUnloaded.EnterWriteLock(); + try + { + _subTypesLookupCache.Clear(); + + foreach (KeyValuePair loadedAcl in LoadedACLs) + { + if (loadedAcl.Value.Acl is not null) + { + foreach (Delegate del in IsReadyToUnloadACL.GetInvocationList()) + { + if (del is System.Func { } func) + { + if (!func.Invoke(loadedAcl.Value)) + return false; // Not ready, exit + } + } + + foreach (Assembly assembly in loadedAcl.Value.Acl.Assemblies) + { + OnAssemblyUnloading?.Invoke(assembly); + } + + UnloadingACLs.Add(new WeakReference(loadedAcl.Value.Acl, true)); + loadedAcl.Value.ClearTypesList(); + loadedAcl.Value.Acl.Unload(); + OnACLUnload?.Invoke(loadedAcl.Value.Id); + } + } + + LoadedACLs.Clear(); + return true; + } + catch + { + // should never happen + return false; + } + finally + { + OpsLockUnloaded.ExitWriteLock(); + OpsLockLoaded.ExitWriteLock(); + } + } + + + [MethodImpl(MethodImplOptions.NoInlining)] + public bool FinalizeDispose() + { + bool isUnloaded; + OpsLockUnloaded.EnterUpgradeableReadLock(); + try + { + List> toRemove = new(); + foreach (WeakReference weakReference in UnloadingACLs) + { + if (!weakReference.TryGetTarget(out _)) + { + toRemove.Add(weakReference); + } + } + + if (toRemove.Any()) + { + OpsLockUnloaded.EnterWriteLock(); + try + { + foreach (WeakReference reference in toRemove) + { + UnloadingACLs.Remove(reference); + } + } + finally + { + OpsLockUnloaded.ExitWriteLock(); + } + } + isUnloaded = !UnloadingACLs.Any(); + } + finally + { + OpsLockUnloaded.ExitUpgradeableReadLock(); + } + + return isUnloaded; + } + + /// + /// Tries to retrieve the LoadedACL with the given ID or null if none is found. + /// WARNING: External references to this ACL with long lifespans should be kept in a WeakReference + /// to avoid causing unloading/disposal issues. + /// + /// GUID of the ACL. + /// The found ACL or null if none was found. + /// Whether or not an ACL was found. + [MethodImpl(MethodImplOptions.NoInlining)] + public bool TryGetACL(Guid id, out LoadedACL acl) + { + acl = null; + OpsLockLoaded.EnterReadLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id)) + return false; + acl = LoadedACLs[id]; + return true; + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + + /// + /// Gets or creates an AssemblyCtxLoader for the given ID. Creates if the ID is empty or no ACL can be found. + /// [IMPORTANT] After calling this method, the id you use should be taken from the acl container (acl.Id). + /// + /// + /// + /// Should only return false if an error occurs. + [MethodImpl(MethodImplOptions.NoInlining)] + private bool GetOrCreateACL(Guid id, out LoadedACL acl) + { + OpsLockLoaded.EnterUpgradeableReadLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null) + { + OpsLockLoaded.EnterWriteLock(); + try + { + id = Guid.NewGuid(); + acl = new LoadedACL(id, this); + LoadedACLs[id] = acl; + return true; + } + finally + { + OpsLockLoaded.ExitWriteLock(); + } + } + else + { + acl = LoadedACLs[id]; + return true; + } + + } + catch + { + // should never happen but in-case + acl = null; + return false; + } + finally + { + OpsLockLoaded.ExitUpgradeableReadLock(); + } + } + + + [MethodImpl(MethodImplOptions.NoInlining)] + private bool DisposeACL(Guid id) + { + OpsLockLoaded.EnterWriteLock(); + OpsLockUnloaded.EnterWriteLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null) + { + return false; // nothing to dispose of + } + + var acl = LoadedACLs[id]; + + foreach (Assembly assembly in acl.Acl.Assemblies) + { + OnAssemblyUnloading?.Invoke(assembly); + } + + _subTypesLookupCache.Clear(); + UnloadingACLs.Add(new WeakReference(acl.Acl, true)); + acl.Acl.Unload(); + OnACLUnload?.Invoke(acl.Id); + + return true; + } + catch + { + // should never happen + return false; + } + finally + { + OpsLockLoaded.ExitWriteLock(); + OpsLockUnloaded.ExitWriteLock(); + } + } + + internal AssemblyManager() + { + RebuildTypesList(); + } + + /// + /// Rebuilds the list of types in the default assembly load context. + /// + private void RebuildTypesList() + { + try + { + _defaultContextTypes = AssemblyLoadContext.Default.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .ToImmutableDictionary(t => t.FullName ?? t.Name, t => t); + _subTypesLookupCache.Clear(); + } + catch(ArgumentException _) + { + try + { + // some types must've had duplicate type names, build the list while filtering + Dictionary types = new(); + foreach (var type in AssemblyLoadContext.Default.Assemblies.SelectMany(a => a.GetSafeTypes())) + { + try + { + types.TryAdd(type.FullName ?? type.Name, type); + } + catch + { + // ignore, null key exception + } + } + + _defaultContextTypes = types.ToImmutableDictionary(); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to create list of default assembly types! Default AssemblyLoadContext types searching not available."); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Exception Details :{e.Message} | {e.InnerException}"); +#endif + _defaultContextTypes = ImmutableDictionary.Empty; + } + } + } + + #endregion + + #region Data + + private readonly ConcurrentDictionary> _subTypesLookupCache = new(); + private ImmutableDictionary _defaultContextTypes; + private readonly ConcurrentDictionary LoadedACLs = new(); + private readonly List> UnloadingACLs= new(); + private readonly ReaderWriterLockSlim OpsLockLoaded = new ReaderWriterLockSlim(); + private readonly ReaderWriterLockSlim OpsLockUnloaded = new ReaderWriterLockSlim(); + + #endregion + + #region TypeDefs + + + public sealed class LoadedACL + { + public readonly Guid Id; + private ImmutableDictionary _assembliesTypes = ImmutableDictionary.Empty; + public readonly MemoryFileAssemblyContextLoader Acl; + private readonly AssemblyManager _manager; + + internal LoadedACL(Guid id, AssemblyManager manager) + { + this.Id = id; + this.Acl = new(manager); + this._manager = manager; + } + public ImmutableDictionary AssembliesTypes => _assembliesTypes; + + /// + /// Rebuild the list of types from assemblies loaded in the AsmCtxLoader. + /// + internal void RebuildTypesList() + { + ClearTypesList(); + try + { + _assembliesTypes = this.Acl.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .ToImmutableDictionary(t => t.FullName ?? t.Name, t => t); + } + catch(ArgumentException _) + { + // some types must've had duplicate type names, build the list while filtering + Dictionary types = new(); + foreach (var type in this.Acl.Assemblies.SelectMany(a => a.GetSafeTypes())) + { + try + { + types.TryAdd(type.FullName ?? type.Name, type); + } + catch + { + // ignore, null key exception + } + } + + _assembliesTypes = types.ToImmutableDictionary(); + } + } + + internal void ClearTypesList() + { + _assembliesTypes.Clear(); + } + } + + #endregion +} + +public static class AssemblyExtensions +{ + /// + /// Gets all types in the given assembly. Handles invalid type scenarios. + /// + /// The assembly to scan + /// An enumerable collection of types. + public static IEnumerable GetSafeTypes(this Assembly assembly) + { + // Based on https://github.com/Qkrisi/ktanemodkit/blob/master/Assets/Scripts/ReflectionHelper.cs#L53-L67 + + try + { + return assembly.GetTypes(); + } + catch (ReflectionTypeLoadException re) + { + try + { + return re.Types.Where(x => x != null)!; + } + catch (InvalidOperationException ioe) + { + return new List(); + } + } + catch (Exception e) + { + return new List(); + } + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs new file mode 100644 index 0000000000..78084237f8 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs @@ -0,0 +1,978 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading; +using Barotrauma.Steam; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using MonoMod.Utils; + +namespace Barotrauma; + +public sealed class CsPackageManager : IDisposable +{ + #region PRIVATE_FUNCDATA + + private static readonly CSharpParseOptions ScriptParseOptions = CSharpParseOptions.Default + .WithPreprocessorSymbols(new[] + { +#if SERVER + "SERVER" +#elif CLIENT + "CLIENT" +#else + "UNDEFINED" +#endif +#if DEBUG + ,"DEBUG" +#endif + }); + +#if WINDOWS + private const string PLATFORM_TARGET = "Windows"; +#elif OSX + private const string PLATFORM_TARGET = "OSX"; +#elif LINUX + private const string PLATFORM_TARGET = "Linux"; +#endif + +#if CLIENT + private const string ARCHITECTURE_TARGET = "Client"; +#elif SERVER + private const string ARCHITECTURE_TARGET = "Server"; +#endif + + private static readonly CSharpCompilationOptions CompilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + .WithMetadataImportOptions(MetadataImportOptions.All) +#if DEBUG + .WithOptimizationLevel(OptimizationLevel.Debug) +#else + .WithOptimizationLevel(OptimizationLevel.Release) +#endif + .WithAllowUnsafe(true); + + private static readonly SyntaxTree BaseAssemblyImports = CSharpSyntaxTree.ParseText( + new StringBuilder() + .AppendLine("using System.Reflection;") + .AppendLine("using Barotrauma;") + .AppendLine("using System.Runtime.CompilerServices;") +#if CLIENT + .AppendLine("[assembly: IgnoresAccessChecksTo(\"Barotrauma\")]") +#elif SERVER + .AppendLine("[assembly: IgnoresAccessChecksTo(\"DedicatedServer\")]") +#endif + .ToString(), + ScriptParseOptions); + + private const string SCRIPT_FILE_REGEX = "*.cs"; + private const string ASSEMBLY_FILE_REGEX = "*.dll"; + + private readonly float _assemblyUnloadTimeoutSeconds = 4f; + private Guid _publicizedAssemblyLoader; + private readonly List _currentPackagesByLoadOrder = new(); + private readonly Dictionary> _packagesDependencies = new(); + private readonly Dictionary _loadedCompiledPackageAssemblies = new(); + private readonly Dictionary _reverseLookupGuidList = new(); + private readonly Dictionary> _loadedPlugins = new (); + private readonly Dictionary> _pluginTypes = new(); // where Type : IAssemblyPlugin + private readonly Dictionary _packageRunConfigs = new(); + private readonly Dictionary> _luaRegisteredTypes = new(); + private readonly AssemblyManager _assemblyManager; + private readonly LuaCsSetup _luaCsSetup; + private DateTime _assemblyUnloadStartTime; + + + #endregion + + #region PUBLIC_API + + #region LUA_EXTENSIONS + + /// + /// Searches for all types in all loaded assemblies from content packages who's names contain the name string and registers them with the Lua Interpreter. + /// + /// + /// + /// + public bool LuaTryRegisterPackageTypes(string name, bool caseSensitive = false) + { + if (!AssembliesLoaded) + return false; + var matchingPacks = _loadedCompiledPackageAssemblies + .Where(kvp => kvp.Key.Name.ToLowerInvariant().Contains(name.ToLowerInvariant())) + .Select(kvp => kvp.Value) + .ToImmutableList(); + if (!matchingPacks.Any()) + return false; + var types = matchingPacks + .Where(guid => !_luaRegisteredTypes.ContainsKey(guid)) + .Select(guid => new KeyValuePair>( + guid, + _assemblyManager.TryGetSubTypesFromACL(guid, out var types) + ? types.ToImmutableList() + : ImmutableList.Empty)) + .ToImmutableList(); + if (!types.Any()) + return false; + foreach (var kvp in types) + { + _luaRegisteredTypes[kvp.Key] = kvp.Value; + foreach (Type type in kvp.Value) + { + MoonSharp.Interpreter.UserData.RegisterType(type); + } + } + + return true; + } + + #endregion + + /// + /// Whether or not assemblies have been loaded. + /// + public bool AssembliesLoaded { get; private set; } + + + /// + /// Whether or not loaded plugins had their preloader run. + /// + public bool PluginsPreInit { get; private set; } + + /// + /// Whether or not plugins' types have been instantiated. + /// + public bool PluginsInitialized { get; private set; } = false; + + /// + /// Whether or not plugins are fully loaded. + /// + public bool PluginsLoaded { get; private set; } = false; + + public IEnumerable GetCurrentPackagesByLoadOrder() => _currentPackagesByLoadOrder; + + /// + /// Tries to find the content package that a given plugin belongs to. + /// + /// Package if found, null otherwise. + /// The IAssemblyPlugin type to find. + /// + public bool TryGetPackageForPlugin(out ContentPackage package) where T : IAssemblyPlugin + { + package = null; + + var t = typeof(T); + var guid = _pluginTypes + .Where(kvp => kvp.Value.Contains(t)) + .Select(kvp => kvp.Key) + .FirstOrDefault(Guid.Empty); + + if (guid.Equals(Guid.Empty) || !_reverseLookupGuidList.ContainsKey(guid) || _reverseLookupGuidList[guid] is null) + return false; + package = _reverseLookupGuidList[guid]; + return true; + } + + + /// + /// Tries to get the loaded plugins for a given package. + /// + /// Package to find. + /// The collection of loaded plugins. + /// + public bool TryGetLoadedPluginsForPackage(ContentPackage package, out IEnumerable loadedPlugins) + { + loadedPlugins = null; + if (package is null || !_loadedCompiledPackageAssemblies.ContainsKey(package)) + return false; + var guid = _loadedCompiledPackageAssemblies[package]; + if (guid.Equals(Guid.Empty) || !_loadedPlugins.ContainsKey(guid)) + return false; + loadedPlugins = _loadedPlugins[guid]; + return true; + } + + /// + /// Called when clean up is being performed. Use when relying on or making use of references from this manager. + /// + public event Action OnDispose; + + public void Dispose() + { + // send events for cleanup + OnDispose?.Invoke(); + // cleanup events + if (OnDispose is not null) + { + foreach (Delegate del in OnDispose.GetInvocationList()) + { + OnDispose -= (del as System.Action); + } + } + + // cleanup plugins and assemblies + ReflectionUtils.ResetCache(); + UnloadPlugins(); + + // try cleaning up the assemblies + _pluginTypes.Clear(); // remove assembly references + _loadedPlugins.Clear(); + + // lua cleanup + foreach (var kvp in _luaRegisteredTypes) + { + foreach (Type type in kvp.Value) + { + MoonSharp.Interpreter.UserData.UnregisterType(type); + } + } + _luaRegisteredTypes.Clear(); + + _assemblyUnloadStartTime = DateTime.Now; + _publicizedAssemblyLoader = Guid.Empty; + + // we can't wait forever or app dies but we can try to be graceful + while (!_assemblyManager.TryBeginDispose()) + { + if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now) + { + break; + } + } + + _assemblyUnloadStartTime = DateTime.Now; + while (!_assemblyManager.FinalizeDispose()) + { + if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now) + { + break; + } + } + + _assemblyManager.OnAssemblyLoaded -= AssemblyManagerOnAssemblyLoaded; + _assemblyManager.OnAssemblyUnloading -= AssemblyManagerOnAssemblyUnloading; + + _publicizedAssemblyLoader = Guid.Empty; + + // clear lists after cleaning up + _packagesDependencies.Clear(); + _loadedCompiledPackageAssemblies.Clear(); + _reverseLookupGuidList.Clear(); + _packageRunConfigs.Clear(); + _currentPackagesByLoadOrder.Clear(); + + AssembliesLoaded = false; + GC.SuppressFinalize(this); + } + + /// + /// Begins the loading process of scanning packages for scripts and binary assemblies, compiling and executing them. + /// + /// + public AssemblyLoadingSuccessState LoadAssemblyPackages() + { + if (AssembliesLoaded) + { + return AssemblyLoadingSuccessState.AlreadyLoaded; + } + + _assemblyManager.OnAssemblyLoaded += AssemblyManagerOnAssemblyLoaded; + _assemblyManager.OnAssemblyUnloading += AssemblyManagerOnAssemblyUnloading; + + // load publicized assemblies + var publicizedDir = Path.Combine(Environment.CurrentDirectory, "Publicized"); + ImmutableList publicizedAssemblies = ImmutableList.Empty; + if (Directory.Exists(publicizedDir)) + { + // search for assemblies + var list = Directory.GetFiles(publicizedDir, "*.dll") +#if CLIENT + .Where(s => !s.ToLowerInvariant().EndsWith("dedicatedserver.dll")); +#elif SERVER + .Where(s => !s.ToLowerInvariant().EndsWith("barotrauma.dll")); +#endif + + // try load them into an acl + var loadState = _assemblyManager.LoadAssembliesFromLocations(list, ref _publicizedAssemblyLoader); + + // loaded + if (loadState is AssemblyLoadingSuccessState.Success) + { + if (_assemblyManager.TryGetACL(_publicizedAssemblyLoader, out var acl)) + { + publicizedAssemblies = acl.Acl.Assemblies.ToImmutableList(); + _assemblyManager.SetACLToTemplateMode(_publicizedAssemblyLoader); + } + } + } + + + // get packages + IEnumerable packages = BuildPackagesList(); + + // check and load config + _packageRunConfigs.AddRange(packages + .Select(p => new KeyValuePair(p, GetRunConfigForPackage(p))) + .ToDictionary(p => p.Key, p=> p.Value)); + + // filter not to be loaded + var cpToRun = _packageRunConfigs + .Where(kvp => ShouldRunPackage(kvp.Key, kvp.Value)) + .Select(kvp => kvp.Key) + .ToImmutableList(); + + // build dependencies map + bool reliableMap = TryBuildDependenciesMap(cpToRun, out var packDeps); + if (!reliableMap) + { + ModUtils.Logging.PrintMessage($"{nameof(CsPackageManager)}: Unable to create reliable dependencies map."); + } + + _packagesDependencies.AddRange(packDeps.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value.ToImmutableList()) + ); + + List packagesToLoadInOrder = new(); + + // build load order + if (reliableMap && OrderAndFilterPackagesByDependencies( + _packagesDependencies, + out var readyToLoad, + out var cannotLoadPackages, + null)) + { + packagesToLoadInOrder.AddRange(readyToLoad); + if (cannotLoadPackages is not null) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the following mods due to dependency errors:"); + foreach (var pair in cannotLoadPackages) + { + ModUtils.Logging.PrintError($"Package: {pair.Key.Name} | Reason: {pair.Value}"); + } + } + } + else + { + // use unsorted list on failure and send error message. + packagesToLoadInOrder.AddRange(_packagesDependencies.Select( p=> p.Key)); + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to create a reliable load order. Defaulting to unordered loading!"); + } + + // get assemblies and scripts' filepaths from packages + var toLoad = packagesToLoadInOrder + .Select(cp => new KeyValuePair( + cp, + new LoadableData( + TryScanPackagesForAssemblies(cp, out var list1) ? list1 : null, + TryScanPackageForScripts(cp, out var list2) ? list2 : null))) + .ToImmutableDictionary(); + + HashSet badPackages = new(); + foreach (var pair in toLoad) + { + // check if unloadable + if (badPackages.Contains(pair.Key)) + continue; + + // try load binary assemblies + var id = Guid.Empty; // id for the ACL for this package defined by AssemblyManager. + AssemblyLoadingSuccessState successState; + if (pair.Value.AssembliesFilePaths is not null && pair.Value.AssembliesFilePaths.Any()) + { + ModUtils.Logging.PrintMessage($"Loading assemblies for CPackage {pair.Key.Name}"); +#if DEBUG + foreach (string assembliesFilePath in pair.Value.AssembliesFilePaths) + { + ModUtils.Logging.PrintMessage($"Found assemblies located at {Path.GetFullPath(ModUtils.IO.SanitizePath(assembliesFilePath))}"); + } +#endif + + successState = _assemblyManager.LoadAssembliesFromLocations(pair.Value.AssembliesFilePaths, ref id); + + // error handling + if (successState is not AssemblyLoadingSuccessState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the binary assemblies for package {pair.Key.Name}. Error: {successState.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + continue; + } + } + + // try compile scripts to assemblies + if (pair.Value.ScriptsFilePaths is not null && pair.Value.ScriptsFilePaths.Any()) + { + ModUtils.Logging.PrintMessage($"Loading scripts for CPackage {pair.Key.Name}"); + List syntaxTrees = new(); + + syntaxTrees.Add(GetPackageScriptImports()); + bool abortPackage = false; + // load scripts data from files + foreach (string scriptPath in pair.Value.ScriptsFilePaths) + { + var state = ModUtils.IO.GetOrCreateFileText(scriptPath, out string fileText, null, false); + // could not load file data + if (state is not ModUtils.IO.IOActionResultState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {state.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + + try + { + CancellationToken token = new(); + syntaxTrees.Add(SyntaxFactory.ParseSyntaxTree(fileText, ScriptParseOptions, scriptPath, Encoding.Default, token)); + // cancel if parsing failed + if (token.IsCancellationRequested) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: Syntax Parse Error."); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + } + catch (Exception e) + { + // unknown error + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {e.Message}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + + } + + if (abortPackage) + continue; + + // try compile + successState = _assemblyManager.LoadAssemblyFromMemory( + pair.Key.Name.Replace(" ",""), + syntaxTrees, + null, + CompilationOptions, + ref id, publicizedAssemblies); + + if (successState is not AssemblyLoadingSuccessState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to compile script assembly for package {pair.Key.Name}. Error: {successState.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + continue; + } + } + + // something was loaded, add to index + if (id != Guid.Empty) + { + ModUtils.Logging.PrintMessage($"Assemblies from CPackage {pair.Key.Name} loaded with Guid {id}."); + _loadedCompiledPackageAssemblies.Add(pair.Key, id); + _reverseLookupGuidList.Add(id, pair.Key); + } + } + + // update loaded packages to exclude bad packages + _currentPackagesByLoadOrder.AddRange(toLoad + .Where(p => !badPackages.Contains(p.Key)) + .Select(p => p.Key)); + + // build list of plugins + foreach (var pair in _loadedCompiledPackageAssemblies) + { + if (_assemblyManager.TryGetSubTypesFromACL(pair.Value, out var types)) + { + _pluginTypes[pair.Value] = types.ToImmutableHashSet(); + foreach (var type in _pluginTypes[pair.Value]) + { + ModUtils.Logging.PrintMessage($"Loading type: {type.Name}"); + } + } + } + + this.AssembliesLoaded = true; + return AssemblyLoadingSuccessState.Success; + + + bool ShouldRunPackage(ContentPackage package, RunConfig config) + { + if (config.AutoGenerated) + return false; + return (!_luaCsSetup.Config.TreatForcedModsAsNormal && config.IsForced()) + || (ContentPackageManager.EnabledPackages.All.Contains(package) && config.IsForcedOrStandard()); + } + + void UpdatePackagesToDisable(ref HashSet list, + ContentPackage newDisabledPackage, + IEnumerable>> dependenciesMap) + { + list.Add(newDisabledPackage); + foreach (var package in dependenciesMap) + { + if (package.Value.Contains(newDisabledPackage)) + list.Add(newDisabledPackage); + } + } + } + + /// + /// Executes instantiated plugins' Initialize() and OnLoadCompleted() methods. + /// + public void RunPluginsInit() + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without any loaded assemblies!"); + return; + } + + if (!PluginsInitialized) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without type instantiation!"); + return; + } + + if (PluginsLoaded) + return; + + foreach (var contentPlugins in _loadedPlugins) + { + // init + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.Initialize(), $"{nameof(IAssemblyPlugin.Initialize)}", plugin.GetType().Name); + } + } + + foreach (var contentPlugins in _loadedPlugins) + { + // load complete + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.OnLoadCompleted(), $"{nameof(IAssemblyPlugin.OnLoadCompleted)}", plugin.GetType().Name); + } + } + + PluginsLoaded = true; + } + + /// + /// Executes instantiated plugins' PreInitPatching() method. + /// + public void RunPluginsPreInit() + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without any loaded assemblies!"); + return; + } + + if (!PluginsInitialized) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without type initialization!"); + return; + } + + if (PluginsPreInit) + { + return; + } + + foreach (var contentPlugins in _loadedPlugins) + { + // init + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.PreInitPatching(), $"{nameof(IAssemblyPlugin.PreInitPatching)}", plugin.GetType().Name); + } + } + + PluginsPreInit = true; + } + + /// + /// Initializes plugin types that are registered. + /// + /// + public void InstantiatePlugins(bool force = false) + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to instantiate plugins without any loaded assemblies!"); + return; + } + + if (PluginsInitialized) + { + if (force) + UnloadPlugins(); + else + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to load plugins when they were already loaded!"); + return; + } + } + + foreach (var pair in _pluginTypes) + { + // instantiate + foreach (Type type in pair.Value) + { + if (!_loadedPlugins.ContainsKey(pair.Key)) + _loadedPlugins.Add(pair.Key, new()); + else if (_loadedPlugins[pair.Key] is null) + _loadedPlugins[pair.Key] = new(); + IAssemblyPlugin plugin = null; + try + { + plugin = (IAssemblyPlugin)Activator.CreateInstance(type); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}. Now disposing..."); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}"); +#endif + TryRun(() => plugin?.Dispose(), "Dispose", type.FullName ?? type.Name); + + plugin = null; + } + if (plugin is not null) + _loadedPlugins[pair.Key].Add(plugin); + else + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}"); + } + } + + PluginsInitialized = true; + } + + /// + /// Unloads all plugins by calling Dispose() on them. Note: This does not remove their external references nor + /// unregister their types. + /// + public void UnloadPlugins() + { + foreach (var contentPlugins in _loadedPlugins) + { + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.Dispose(), $"{nameof(IAssemblyPlugin.Dispose)}", plugin.GetType().Name); + } + contentPlugins.Value.Clear(); + } + + _loadedPlugins.Clear(); + + PluginsInitialized = false; + PluginsPreInit = false; + PluginsLoaded = false; + } + + + /// + /// Gets the RunConfig.xml for the given package located at [cp_root]/CSharp/RunConfig.xml. + /// Generates a default config if one is not found. + /// + /// The package to search for. + /// RunConfig data. + /// True if a config is loaded, false if one was created. + public static bool GetOrCreateRunConfig(ContentPackage package, out RunConfig config) + { + var path = System.IO.Path.Combine(Path.GetFullPath(package.Dir), "CSharp", "RunConfig.xml"); + if (!File.Exists(path)) + { + config = new RunConfig(true).Sanitize(); + return false; + } + return ModUtils.IO.LoadOrCreateTypeXml(out config, path, () => new RunConfig(true).Sanitize(), false); + } + + #endregion + + #region INTERNALS + + private void TryRun(Action action, string messageMethodName, string messageTypeName) + { + try + { + action?.Invoke(); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while running {messageMethodName}() on plugin of type {messageTypeName}"); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}"); +#endif + } + } + + private void AssemblyManagerOnAssemblyUnloading(Assembly assembly) + { + ReflectionUtils.RemoveAssemblyFromCache(assembly); + } + + private void AssemblyManagerOnAssemblyLoaded(Assembly assembly) + { + //ReflectionUtils.AddNonAbstractAssemblyTypes(assembly); + // As ReflectionUtils.GetDerivedNonAbstract is only used for Prefabs & Barotrauma-specific implementing types, + // we can safely not register System/Core assemblies. + if (assembly.FullName is not null && assembly.FullName.StartsWith("System.")) + return; + ReflectionUtils.AddNonAbstractAssemblyTypes(assembly, true); + } + + internal CsPackageManager([NotNull] AssemblyManager assemblyManager, [NotNull] LuaCsSetup luaCsSetup) + { + this._assemblyManager = assemblyManager; + this._luaCsSetup = luaCsSetup; + } + + ~CsPackageManager() + { + this.Dispose(); + } + + private static bool TryScanPackageForScripts(ContentPackage package, out ImmutableList scriptFilePaths) + { + string pathShared = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", "Shared"); + string pathArch = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", ARCHITECTURE_TARGET); + + List files = new(); + + if (Directory.Exists(pathShared)) + files.AddRange(Directory.GetFiles(pathShared, SCRIPT_FILE_REGEX, SearchOption.AllDirectories)); + if (Directory.Exists(pathArch)) + files.AddRange(Directory.GetFiles(pathArch, SCRIPT_FILE_REGEX, SearchOption.AllDirectories)); + + if (files.Count > 0) + { + scriptFilePaths = files.ToImmutableList(); + return true; + } + scriptFilePaths = ImmutableList.Empty; + return false; + } + + private static bool TryScanPackagesForAssemblies(ContentPackage package, out ImmutableList assemblyFilePaths) + { + string path = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "bin", ARCHITECTURE_TARGET, PLATFORM_TARGET); + + if (!Directory.Exists(path)) + { + assemblyFilePaths = ImmutableList.Empty; + return false; + } + + assemblyFilePaths = System.IO.Directory.GetFiles(path, ASSEMBLY_FILE_REGEX, SearchOption.AllDirectories) + .ToImmutableList(); + return assemblyFilePaths.Count > 0; + } + + private static RunConfig GetRunConfigForPackage(ContentPackage package) + { + if (!GetOrCreateRunConfig(package, out var config)) + config.AutoGenerated = true; + return config; + } + + private IEnumerable BuildPackagesList() + { + // get unique list of content packages. + // Note: there is an old issue where the AllPackages group + // would sometimes not contain packages downloaded from the host, so we union enabled. + return ContentPackageManager.AllPackages.Union(ContentPackageManager.EnabledPackages.All).Where(pack => !pack.Name.ToLowerInvariant().Equals("vanilla")); + } + + + private static SyntaxTree GetPackageScriptImports() => BaseAssemblyImports; + + + /// + /// Builds a list of ContentPackage dependencies for each of the packages in the list. Note: All dependencies must be included in the provided list of packages. + /// + /// List of packages to check + /// Dependencies by package + /// True if all dependencies were found. + private static bool TryBuildDependenciesMap(ImmutableList packages, out Dictionary> dependenciesMap) + { + bool reliableMap = true; // remains true if all deps were found. + dependenciesMap = new(); + foreach (var package in packages) + { + dependenciesMap.Add(package, new()); + if (GetOrCreateRunConfig(package, out var config)) + { + if (config.Dependencies is null || !config.Dependencies.Any()) + continue; + + foreach (RunConfig.Dependency dependency in config.Dependencies) + { + ContentPackage dep = packages.FirstOrDefault(p => + (dependency.SteamWorkshopId != 0 && p.TryExtractSteamWorkshopId(out var steamWorkshopId) + && steamWorkshopId.Value == dependency.SteamWorkshopId) + || (!dependency.PackageName.IsNullOrWhiteSpace() && p.Name.ToLowerInvariant().Contains(dependency.PackageName.ToLowerInvariant())), null); + + if (dep is not null) + { + dependenciesMap[package].Add(dep); + } + else + { + ModUtils.Logging.PrintError($"Warning! The ContentPackage {package.Name} lists a dependency of (STEAMID: {dependency.SteamWorkshopId}, PackageName: {dependency.PackageName}) but it could not be found in the to-be-loaded CSharp packages list!"); + reliableMap = false; + } + } + } + else + { + ModUtils.Logging.PrintMessage($"Warning! Could not retrieve RunConfig for ContentPackage {package.Name}!"); + } + } + + return reliableMap; + } + + /// + /// Given a table of packages and dependent packages, will sort them by dependency loading order along with packages + /// that cannot be loaded due to errors or failing the predicate checks. + /// + /// A dictionary/map with key as the package and the elements as it's dependencies. + /// List of packages that are ready to load and in the correct order. + /// Packages with errors or cyclic dependencies. Element is error message. Null if empty. + /// Optional: Allows for a custom checks to be performed on each package. + /// Returns a bool indicating if the package is ready to load. + /// Whether or not the process produces a usable list. + private static bool OrderAndFilterPackagesByDependencies( + Dictionary> packages, + out IEnumerable readyToLoad, + out IEnumerable> cannotLoadPackages, + Func packageChecksPredicate = null) + { + HashSet completedPackages = new(); + List readyPackages = new(); + Dictionary unableToLoad = new(); + HashSet currentNodeChain = new(); + + readyToLoad = readyPackages; + + try + { + foreach (var toProcessPack in packages) + { + ProcessPackage(toProcessPack.Key, toProcessPack.Value); + } + + PackageProcRet ProcessPackage(ContentPackage packageToProcess, IEnumerable dependencies) + { + //cyclic handling + if (unableToLoad.ContainsKey(packageToProcess)) + { + return PackageProcRet.BadPackage; + } + + // already processed + if (completedPackages.Contains(packageToProcess)) + { + return PackageProcRet.AlreadyCompleted; + } + + // cyclic check + if (currentNodeChain.Contains(packageToProcess)) + { + StringBuilder sb = new(); + sb.AppendLine("Error: Cyclic Dependency. ") + .Append( + "The following ContentPackages rely on eachother in a way that makes it impossible to know which to load first! ") + .Append( + "Note: the package listed twice shows where the cycle starts/ends and is not necessarily the problematic package."); + int i = 0; + foreach (var package in currentNodeChain) + { + i++; + sb.AppendLine($"{i}. {package.Name}"); + } + + sb.AppendLine($"{i}. {packageToProcess.Name}"); + unableToLoad.Add(packageToProcess, sb.ToString()); + completedPackages.Add(packageToProcess); + return PackageProcRet.BadPackage; + } + + if (packageChecksPredicate is not null && !packageChecksPredicate.Invoke(packageToProcess)) + { + unableToLoad.Add(packageToProcess, $"Unable to load package {packageToProcess.Name} due to failing checks."); + completedPackages.Add(packageToProcess); + return PackageProcRet.BadPackage; + } + + currentNodeChain.Add(packageToProcess); + + foreach (ContentPackage dependency in dependencies) + { + // The mod lists a dependent that was not found during the discovery phase. + if (!packages.ContainsKey(dependency)) + { + // search to see if it's enabled + if (!ContentPackageManager.EnabledPackages.All.Contains(dependency)) + { + // present warning but allow loading anyways, better to let the user just disable the package if it's really an issue. + ModUtils.Logging.PrintError( + $"Warning: the ContentPackage of {packageToProcess.Name} requires the Dependency {dependency.Name} but this package wasn't found in the enabled mods list!"); + } + + continue; + } + + var ret = ProcessPackage(dependency, packages[dependency]); + + if (ret is PackageProcRet.BadPackage) + { + if (!unableToLoad.ContainsKey(packageToProcess)) + { + unableToLoad.Add(packageToProcess, $"Error: Dependency failure. Failed to load {dependency.Name}"); + } + currentNodeChain.Remove(packageToProcess); + if (!completedPackages.Contains(packageToProcess)) + { + completedPackages.Add(packageToProcess); + } + return PackageProcRet.BadPackage; + } + } + + currentNodeChain.Remove(packageToProcess); + completedPackages.Add(packageToProcess); + readyPackages.Add(packageToProcess); + return PackageProcRet.Completed; + } + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"Error while generating dependency loading order! Exception: {e.Message}"); +#if DEBUG + ModUtils.Logging.PrintError($"Stack Trace: {e.StackTrace}"); +#endif + cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null; + return false; + } + cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null; + return true; + } + + private enum PackageProcRet : byte + { + AlreadyCompleted, + Completed, + BadPackage + } + + private record LoadableData(ImmutableList AssembliesFilePaths, ImmutableList ScriptsFilePaths); + + #endregion +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs new file mode 100644 index 0000000000..5a450ba744 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs @@ -0,0 +1,22 @@ +using System; + +namespace Barotrauma; + +public interface IAssemblyPlugin : IDisposable +{ + /// + /// Called on plugin normal, use this for basic/core loading that does not rely on any other modded content. + /// + void Initialize(); + + /// + /// Called once all plugins have been loaded. if you have integrations with any other mod, put that code here. + /// + void OnLoadCompleted(); + + + /// + /// Called before Barotrauma initializes vanilla content. WARNING: This method may be called before Initialize()! + /// + void PreInitPatching(); +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs new file mode 100644 index 0000000000..8c29f07e44 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs @@ -0,0 +1,289 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.Loader; +using System.Threading; +using Barotrauma; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Emit; + +namespace Barotrauma; + +/// +/// AssemblyLoadContext to compile from syntax trees in memory and to load from disk/file. Provides dependency resolution. +/// [IMPORTANT] Only supports 1 in-memory compiled assembly at a time. Use more instances if you need more. +/// [IMPORTANT] All file assemblies required for the compilation of syntax trees should be loaded first. +/// +public class MemoryFileAssemblyContextLoader : AssemblyLoadContext +{ + // public + // ReSharper disable MemberCanBePrivate.Global + public Assembly CompiledAssembly { get; private set; } = null; + public byte[] CompiledAssemblyImage { get; private set; } = null; + // ReSharper restore MemberCanBePrivate.Global + // internal + private readonly Dictionary _dependencyResolvers = new(); // path-folder, resolver + protected bool IsResolving; //this is to avoid circular dependency lookup. + private AssemblyManager _assemblyManager; + public bool IsTemplateMode { get; set; } = false; + + public MemoryFileAssemblyContextLoader(AssemblyManager assemblyManager) : base(isCollectible: true) + { + this._assemblyManager = assemblyManager; + } + + + /// + /// Try to load the list of disk-file assemblies. + /// + /// Operation success or failure reason. + public AssemblyLoadingSuccessState LoadFromFiles([NotNull] IEnumerable assemblyFilePaths) + { + if (assemblyFilePaths is null) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(LoadFromFiles)}() | The supplied filepath list is null."); + + foreach (string filepath in assemblyFilePaths) + { + // path verification + if (filepath.IsNullOrWhiteSpace()) + continue; + string sanitizedFilePath = System.IO.Path.GetFullPath(filepath.CleanUpPath()); + string directoryKey = System.IO.Path.GetDirectoryName(sanitizedFilePath); + + if (directoryKey is null) + return AssemblyLoadingSuccessState.BadFilePath; + + // setup dep resolver if not available + if (!_dependencyResolvers.ContainsKey(directoryKey) || _dependencyResolvers[directoryKey] is null) + { + _dependencyResolvers[directoryKey] = new AssemblyDependencyResolver(sanitizedFilePath); // supply the first assembly to be loaded + } + + // try loading the assemblies + try + { + LoadFromAssemblyPath(sanitizedFilePath); + } + // on fail of any we're done because we assume that loaded files are related. This ACL needs to be unloaded and collected. + catch (ArgumentNullException ane) + { + return AssemblyLoadingSuccessState.BadFilePath; + } + catch (ArgumentException ae) + { + return AssemblyLoadingSuccessState.BadFilePath; + } + catch (FileLoadException fle) + { + return AssemblyLoadingSuccessState.CannotLoadFile; + } + catch (FileNotFoundException fne) + { + return AssemblyLoadingSuccessState.NoAssemblyFound; + } + catch (BadImageFormatException bfe) + { + return AssemblyLoadingSuccessState.InvalidAssembly; + } + catch (Exception e) + { +#if SERVER + LuaCsLogger.LogError($"Unable to load dependency assembly file at {filepath.CleanUpPath()} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}"); +#elif CLIENT + LuaCsLogger.ShowErrorOverlay($"Unable to load dependency assembly file at {filepath} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}"); +#endif + return AssemblyLoadingSuccessState.ACLLoadFailure; + } + } + + return AssemblyLoadingSuccessState.Success; + } + + + /// + /// Compiles the supplied syntaxtrees and options into an in-memory assembly image. + /// Builds metadata from loaded assemblies, only supply your own if you have in-memory images not managed by the + /// AssemblyManager class. + /// + /// Name of the assembly. Must be supplied for in-memory assemblies. + /// Syntax trees to compile into the assembly. + /// Metadata to be used for compilation. + /// [IMPORTANT] This method builds metadata from loaded assemblies, only supply your own if you have in-memory + /// images not managed by the AssemblyManager class. + /// CSharp compilation options. This method automatically adds the 'IgnoreAccessChecks' property for compilation. + /// Will contain any diagnostic messages for compilation failure. + /// Additional assemblies located in the FileSystem to build metadata references from. + /// Assemblies here will have duplicates by the same name that are currently loaded filtered out. + /// Success state of the operation. + /// Throws exception if any of the required arguments are null. + public AssemblyLoadingSuccessState CompileAndLoadScriptAssembly( + [NotNull] string assemblyName, + [NotNull] IEnumerable syntaxTrees, + IEnumerable externMetadataReferences, + [NotNull] CSharpCompilationOptions compilationOptions, + out string compilationMessages, + IEnumerable externFileAssemblyReferences = null) + { + compilationMessages = ""; + + if (this.CompiledAssembly is not null) + { + return AssemblyLoadingSuccessState.AlreadyLoaded; + } + + var externAssemblyRefs = externFileAssemblyReferences is not null ? externFileAssemblyReferences.ToImmutableList() : ImmutableList.Empty; + var externAssemblyNames = externAssemblyRefs.Any() ? externAssemblyRefs + .Where(a => a.FullName is not null) + .Select(a => a.FullName).ToImmutableHashSet() + : ImmutableHashSet.Empty; + + // verifications + if (assemblyName.IsNullOrWhiteSpace()) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied assembly name is null!"); + + if (syntaxTrees is null) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied syntax tree is null!"); + + // add external references + List metadataReferences = new(); + if (externMetadataReferences is not null) + metadataReferences.AddRange(externMetadataReferences); + + // build metadata refs from global where not an in-memory compiled assembly and not the same assembly as supplied. + metadataReferences.AddRange(AppDomain.CurrentDomain.GetAssemblies() + .Where(a => + { + if (a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")) + return false; + if (a.FullName is null) + return true; + return !externAssemblyNames.Contains(a.FullName); // exclude duplicates + }) + .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) + .Union(externAssemblyRefs // add custom supplied assemblies + .Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))) + .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) + ).ToList()); + + // build metadata refs from in-memory images + foreach (var loadedAcl in _assemblyManager.GetAllLoadedACLs()) + { + if (loadedAcl.Acl.CompiledAssemblyImage is null || loadedAcl.Acl.CompiledAssemblyImage.Length == 0) + continue; + metadataReferences.Add(MetadataReference.CreateFromImage(loadedAcl.Acl.CompiledAssemblyImage)); + } + + // Change inaccessible options to allow public access to restricted members + var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic); + topLevelBinderFlagsProperty?.SetValue(compilationOptions, (uint)1 << 22); + + // begin compilation + using var memoryCompilation = new MemoryStream(); + // compile, emit + var result = CSharpCompilation.Create(assemblyName, syntaxTrees, metadataReferences, compilationOptions).Emit(memoryCompilation); + // check for errors + if (!result.Success) + { + IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); + foreach (Diagnostic diagnostic in failures) + { + compilationMessages += $"\n{diagnostic}"; + } + + return AssemblyLoadingSuccessState.InvalidAssembly; + } + + // read compiled assembly from memory stream into an in-memory assembly & image + memoryCompilation.Seek(0, SeekOrigin.Begin); // reset + try + { + CompiledAssembly = LoadFromStream(memoryCompilation); + CompiledAssemblyImage = memoryCompilation.ToArray(); + } + catch (Exception e) + { +#if SERVER + LuaCsLogger.LogError($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}"); +#elif CLIENT + LuaCsLogger.ShowErrorOverlay($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}"); +#endif + return AssemblyLoadingSuccessState.CannotLoadFromStream; + } + + return AssemblyLoadingSuccessState.Success; + } + + [SuppressMessage("ReSharper", "ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract")] + protected override Assembly Load(AssemblyName assemblyName) + { + if (IsResolving) + return null; //circular resolution fast exit. + + try + { + IsResolving = true; + + // resolve self collection + Assembly ass = this.Assemblies.FirstOrDefault(a => + a.FullName is not null && a.FullName.Equals(assemblyName.FullName), null); + if (ass is not null) + return ass; + + // resolve to local folders + foreach (KeyValuePair pair in _dependencyResolvers) + { + var asspath = pair.Value.ResolveAssemblyToPath(assemblyName); + if (asspath is null) + continue; + ass = LoadFromAssemblyPath(asspath); + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + if (ass is not null) + return ass; + } + + //try resolve against other loaded alcs + foreach (var loadedAcL in _assemblyManager.GetAllLoadedACLs()) + { + if (loadedAcL.Acl is null || loadedAcL.Acl.IsTemplateMode) continue; + + try + { + ass = loadedAcL.Acl.LoadFromAssemblyName(assemblyName); + if (ass is not null) + return ass; + } + catch + { + // LoadFromAssemblyName throws, no need to propagate + } + } + + ass = AssemblyLoadContext.Default.LoadFromAssemblyName(assemblyName); + if (ass is not null) + return ass; + } + finally + { + IsResolving = false; + } + + return null; + } + + + private new void Unload() + { + CompiledAssembly = null; + CompiledAssemblyImage = null; + base.Unload(); + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs new file mode 100644 index 0000000000..93cbb75f78 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs @@ -0,0 +1,111 @@ +using System; +using System.Xml.Serialization; + +namespace Barotrauma; + +[Serializable] +public sealed class RunConfig +{ + /// + /// How should scripts be run on the server. + /// + [XmlElement(ElementName = "Server")] public string Server; + + /// + /// How should scripts be run on the client. + /// + [XmlElement(ElementName = "Client")] public string Client; + + /// + /// List of dependencies by either Steam Workshop ID or by Partial Inclusive Name (ie. "ModDep" will match a mod named "A ModDependency"). + /// PIN Dependency checks if ContentPackage names contains the dependency string. + /// + [XmlArrayItem(ElementName = "Dependency", IsNullable = true, Type = typeof(Dependency))] + [XmlArray] + public Dependency[] Dependencies { get; set; } + + [XmlElement(ElementName = "AutoGenerated")] + public bool AutoGenerated { get; set; } + + public RunConfig(bool autoGenerated) + { + this.AutoGenerated = autoGenerated; + if (autoGenerated) + { + (Client, Server) = ("None", "None"); + } + } + + public RunConfig() { } // For serialization use + + [Serializable] + public sealed class Dependency + { + /// + /// Steam Workshop ID of the dependency. + /// + [XmlElement(ElementName = "SteamWorkshopId")] + public ulong SteamWorkshopId; + + /// + /// Package Name of the dependency. Not needed if SteamWorkshopId is set. + /// + [XmlElement(ElementName = "PackageName")] + public string PackageName; + } + + public RunConfig Sanitize() + { + try + { + Client = SanitizeRunSetting(Client); + } + catch (Exception e) + { + Client = "None"; + } + + try + { + Server = SanitizeRunSetting(Server); + } + catch (Exception e) + { + Server = "None"; + } + + Dependencies ??= new RunConfig.Dependency[] { }; + + static string SanitizeRunSetting(string str) => + str switch + { + null => "None", + "" => "None", + " " => "None", + _ => str[0].ToString().ToUpper() + str.Substring(1).ToLower() + }; + + return this; + } + + public bool IsForced() + { +#if CLIENT + return this.Client.Equals("Forced"); +#elif SERVER + return this.Server.Equals("Forced"); +#endif + } + + public bool IsStandard() + { +#if CLIENT + return this.Client.Equals("Standard"); +#elif SERVER + return this.Server.Equals("Standard"); +#endif + } + + public bool IsForcedOrStandard() => this.IsForced() || this.IsStandard(); + +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs b/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs index f3e6bcdf7a..ac8c601ac7 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs @@ -1,5 +1,6 @@ #nullable enable using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -9,26 +10,43 @@ namespace Barotrauma { public static class ReflectionUtils { - private static readonly Dictionary> cachedNonAbstractTypes - = new Dictionary>(); - - private static readonly Dictionary>> cachedDerivedNonAbstract - = new Dictionary>>(); + private static readonly ConcurrentDictionary> CachedNonAbstractTypes + = new ConcurrentDictionary>(); + private static readonly ConcurrentDictionary> TypeSearchCache = new(); public static IEnumerable GetDerivedNonAbstract() { Type t = typeof(T); - Assembly assembly = typeof(T).Assembly; - lock (cachedNonAbstractTypes) + string typeName = t.FullName ?? t.Name; + + // search quick lookup cache + if (TypeSearchCache.TryGetValue(typeName, out var value)) { - if (!cachedNonAbstractTypes.ContainsKey(assembly)) - { - AddNonAbstractAssemblyTypes(assembly); - } + return value; } + + // doesn't exist so let's add it. + Assembly assembly = typeof(T).Assembly; + if (!CachedNonAbstractTypes.ContainsKey(assembly)) + { + AddNonAbstractAssemblyTypes(assembly); + } + + // build cache from registered assemblies' types. + var list = CachedNonAbstractTypes.Values + .SelectMany(arr => arr.Where(type => type.IsSubclassOf(t))) + .ToImmutableArray(); - #warning TODO: Add safety checks in case an assembly is unloaded without being removed from the cache. - return cachedNonAbstractTypes.Values.SelectMany(s => s.Where(t => t.IsSubclassOf(typeof(T)))); + if (list.Length == 0) + { + return ImmutableArray.Empty; // No types, don't add to cache + } + + if (!TypeSearchCache.TryAdd(typeName, list)) + { + DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Error while adding to quick lookup cache."); + } + return list; } /// @@ -38,7 +56,7 @@ public static IEnumerable GetDerivedNonAbstract() /// Whether or not to overwrite an entry if the assembly already exists within it. public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite = false) { - if (cachedNonAbstractTypes.ContainsKey(assembly)) + if (CachedNonAbstractTypes.ContainsKey(assembly)) { if (!overwrite) { @@ -46,15 +64,20 @@ public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite $"ReflectionUtils::AddNonAbstractAssemblyTypes() | The assembly [{assembly.GetName()}] already exists in the cache."); return; } - cachedNonAbstractTypes.Remove(assembly); + + CachedNonAbstractTypes.Remove(assembly, out _); } try { - if (!cachedNonAbstractTypes.TryAdd(assembly, assembly.GetTypes().Where(t => !t.IsAbstract).ToImmutableArray())) + if (!CachedNonAbstractTypes.TryAdd(assembly, assembly.GetSafeTypes().Where(t => !t.IsAbstract).ToImmutableArray())) { DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Unable to add types from Assembly to cache."); } + else + { + TypeSearchCache.Clear(); // Needs to be rebuilt to include potential new types + } } catch (ReflectionTypeLoadException e) { @@ -66,8 +89,22 @@ public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite /// Removes an assembly from the cache for Barotrauma's Type lookup. /// /// Assembly to remove. - public static void RemoveAssemblyFromCache(Assembly assembly) => cachedNonAbstractTypes.Remove(assembly); + public static void RemoveAssemblyFromCache(Assembly assembly) + { + CachedNonAbstractTypes.Remove(assembly, out _); + TypeSearchCache.Clear(); + } + /// + /// Clears all cached assembly data and rebuilds types list only to include base Barotrauma types. + /// + internal static void ResetCache() + { + CachedNonAbstractTypes.Clear(); + CachedNonAbstractTypes.TryAdd(typeof(ReflectionUtils).Assembly, typeof(ReflectionUtils).Assembly.GetSafeTypes().ToImmutableArray()); + TypeSearchCache.Clear(); + } + public static Option ParseDerived(TInput input) where TInput : notnull where TBase : notnull { static Option none() => Option.None();