From a7c7d159f459a6c7d1ede1772df096fa91b9b648 Mon Sep 17 00:00:00 2001 From: MapleWheels Date: Thu, 7 Sep 2023 03:27:16 -0400 Subject: [PATCH] Assembly and Script Loading System for Content Packages Overhauled. - Supports Dependency System. - Supports Hot Reloading. --- .github/workflows/harden-ci-security.yml | 2 +- .github/workflows/publish-release.yml | 2 +- .../ClientSource/LuaCs/LuaCsNetworking.cs | 2 +- .../ClientSource/LuaCs/LuaCsSetup.cs | 1 + .../Lua/DefaultRegister/RegisterClient.lua | 11 +- .../Lua/DefaultRegister/RegisterShared.lua | 31 +- .../SharedSource/LuaCs/Cs/CsScriptBase.cs | 50 - .../SharedSource/LuaCs/Cs/CsScriptLoader.cs | 286 ----- .../LuaCs/Lua/LuaClasses/LuaUserData.cs | 2 +- .../SharedSource/LuaCs/LuaCsSetup.cs | 137 ++- .../SharedSource/LuaCs/LuaCsUtility.cs | 19 +- .../SharedSource/LuaCs/ModUtils.cs | 332 ++++++ .../LuaCs/{Cs => Plugins}/ACsMod.cs | 26 +- .../LuaCs/Plugins/ApplicationMode.cs | 6 + .../Plugins/AssemblyLoadingSuccessState.cs | 15 + .../LuaCs/Plugins/AssemblyManager.cs | 770 ++++++++++++++ .../LuaCs/Plugins/CsPackageManager.cs | 976 ++++++++++++++++++ .../LuaCs/Plugins/IAssemblyPlugin.cs | 22 + .../MemoryFileAssemblyContextLoader.cs | 289 ++++++ .../SharedSource/LuaCs/Plugins/RunConfig.cs | 111 ++ .../SharedSource/Utils/ReflectionUtils.cs | 71 +- .../BarotraumaTest/LuaCs/HookPatchTests.cs | 73 +- luacs-docs/lua/lua/Networking.lua | 14 +- 23 files changed, 2780 insertions(+), 468 deletions(-) delete mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs delete mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs rename Barotrauma/BarotraumaShared/SharedSource/LuaCs/{Cs => Plugins}/ACsMod.cs (60%) create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs diff --git a/.github/workflows/harden-ci-security.yml b/.github/workflows/harden-ci-security.yml index cc6653ce30..2e79e253e0 100644 --- a/.github/workflows/harden-ci-security.yml +++ b/.github/workflows/harden-ci-security.yml @@ -18,4 +18,4 @@ jobs: with: ref: ${{ inputs.target }} - name: Ensure all actions are pinned to a specific commit - uses: zgosalvez/github-actions-ensure-sha-pinned-actions@555a30da2656b4a7cf47b107800bef097723363e # v2.1.3 + uses: zgosalvez/github-actions-ensure-sha-pinned-actions@f32435541e24cd6a4700a7f52bb2ec59e80603b1 # v2.1.4 diff --git a/.github/workflows/publish-release.yml b/.github/workflows/publish-release.yml index 0f408cef04..d2512e9440 100644 --- a/.github/workflows/publish-release.yml +++ b/.github/workflows/publish-release.yml @@ -157,7 +157,7 @@ jobs: done - name: Publish release - uses: notpeelz/action-gh-create-release@a12edfc71daf5daa7922b931c28e2bf88d3b2ced # v5.0.0 + uses: notpeelz/action-gh-create-release@c1bebd17c8a128e8db4165a68be4dc4e3f106ff1 # v5.0.1 with: target: ${{ inputs.target }} tag: ${{ inputs.tag }} diff --git a/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsNetworking.cs b/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsNetworking.cs index 61e9450d26..2ce99fa878 100644 --- a/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsNetworking.cs +++ b/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsNetworking.cs @@ -133,7 +133,7 @@ private void ReadIds(IReadMessage netMessage) { if (netReceives.ContainsKey(name)) { - netReceives[name](netMessage, null); + netReceives[name](queueMessage, null); } } } diff --git a/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsSetup.cs index 48c3247110..51ef1ca760 100644 --- a/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaClient/ClientSource/LuaCs/LuaCsSetup.cs @@ -34,6 +34,7 @@ public void CheckInitialize() if (GameMain.Client.IsServerOwner) { new GUIMessageBox("", "You have CSharp mods enabled but don't have the Cs For Barotrauma package enabled, those mods might not work."); + Initialize(); return; } diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua index cc6fa9a2e9..d0497fd6c4 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterClient.lua @@ -24,11 +24,8 @@ RegisterBarotrauma("Media.Video") RegisterBarotrauma("SoundsFile") RegisterBarotrauma("SoundPrefab") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.SoundPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.BackgroundMusic]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.GUISound]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.DamageSound]]") -RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SoundPrefab]]") +RegisterBarotrauma("PrefabCollection`1") +RegisterBarotrauma("PrefabSelector`1") RegisterBarotrauma("BackgroundMusic") RegisterBarotrauma("GUISound") RegisterBarotrauma("DamageSound") @@ -57,7 +54,6 @@ RegisterBarotrauma("Particles.Particle") RegisterBarotrauma("Particles.ParticleEmitterProperties") RegisterBarotrauma("Particles.ParticleEmitter") RegisterBarotrauma("Particles.ParticlePrefab") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.Particles.ParticlePrefab]]") RegisterBarotrauma("Lights.LightManager") RegisterBarotrauma("Lights.LightSource") @@ -145,4 +141,5 @@ RegisterBarotrauma("Store") RegisterBarotrauma("UISprite") RegisterBarotrauma("ParamsEditor") -RegisterBarotrauma("Inventory+SlotReference") \ No newline at end of file +RegisterBarotrauma("Inventory+SlotReference") +RegisterBarotrauma("VisualSlot") diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua index f8c34e511c..0b643e0314 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultRegister/RegisterShared.lua @@ -6,8 +6,8 @@ Register("System.Exception") Register("System.Console") Register("System.Exception") -RegisterBarotrauma("Success`2[[Barotrauma.ContentPackage],[System.Exception]]") -RegisterBarotrauma("Failure`2[[Barotrauma.ContentPackage],[System.Exception]]") +RegisterBarotrauma("Success`2") +RegisterBarotrauma("Failure`2") RegisterBarotrauma("LuaSByte") RegisterBarotrauma("LuaByte") @@ -24,8 +24,7 @@ RegisterBarotrauma("GameMain") RegisterBarotrauma("Networking.BanList") RegisterBarotrauma("Networking.BannedPlayer") -RegisterBarotrauma("Range`1[System.Single]") -RegisterBarotrauma("Range`1[System.Int32]") +RegisterBarotrauma("Range`1") RegisterBarotrauma("RichString") RegisterBarotrauma("Identifier") @@ -399,27 +398,11 @@ end RegisterBarotrauma("Camera") RegisterBarotrauma("Key") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.ItemPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.JobPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.CharacterPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.AfflictionPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentTree]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OrderPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OutpostGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.RuinGeneration.RuinGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LocationType]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventPrefab]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventSet]]") -RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventManagerSettings]]") +RegisterBarotrauma("PrefabCollection`1") -RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SkillSettings]]") +RegisterBarotrauma("PrefabSelector`1") -RegisterBarotrauma("Pair`2[[Barotrauma.JobPrefab],[System.Int32]]") - -RegisterBarotrauma("Range`1[System.Single]") +RegisterBarotrauma("Pair`2") RegisterBarotrauma("Items.Components.Signal") RegisterBarotrauma("SubmarineInfo") @@ -461,4 +444,4 @@ LuaUserData.RemoveMember(workshopItem, "AddFavorite") LuaUserData.RemoveMember(workshopItem, "RemoveFavorite") LuaUserData.RemoveMember(workshopItem, "Vote") LuaUserData.RemoveMember(workshopItem, "GetUserVote") -LuaUserData.RemoveMember(workshopItem, "Edit") \ No newline at end of file +LuaUserData.RemoveMember(workshopItem, "Edit") diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs deleted file mode 100644 index 6ac4b8f9e6..0000000000 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptBase.cs +++ /dev/null @@ -1,50 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.CodeAnalysis.Scripting; -using System.Reflection; -using Microsoft.CodeAnalysis.CSharp; -using System.Linq; -using Microsoft.CodeAnalysis; -using System.Runtime.Loader; -using System.Reflection.PortableExecutable; -using System.Reflection.Metadata; -using System.Text; -using System.Runtime.CompilerServices; - -namespace Barotrauma -{ - class CsScriptBase : AssemblyLoadContext - { - - public const string CsScriptAssembly = "NetScriptAssembly"; - - public static readonly string[] LoadedAssemblyName = { - CsScriptBase.CsScriptAssembly - }; - - public static Dictionary Revision = new Dictionary() - { - { CsScriptAssembly, 0} - }; - - public CSharpParseOptions ParseOptions { get; protected set; } - - public CsScriptBase() : base(isCollectible: true) { - ParseOptions = CSharpParseOptions.Default - .WithPreprocessorSymbols(new[] { LuaCsSetup.IsServer ? "SERVER" : (LuaCsSetup.IsClient ? "CLIENT" : "UNDEFINED") }); - } - - public static SyntaxTree AssemblyInfoSyntaxTree(string asmName = null) - { - Revision[asmName] = (int)Revision[asmName] + 1; - var asmInfo = new StringBuilder(); - asmInfo.AppendLine("using System.Reflection;"); - asmInfo.AppendLine($"[assembly: AssemblyMetadata(\"Revision\", \"{Revision[asmName]}\")]"); - asmInfo.AppendLine($"[assembly: AssemblyVersion(\"0.0.0.{Revision[asmName]}\")]"); - return CSharpSyntaxTree.ParseText(asmInfo.ToString(), CSharpParseOptions.Default); - } - - ~CsScriptBase() { } - } -} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs deleted file mode 100644 index 90d858f822..0000000000 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ /dev/null @@ -1,286 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.CodeAnalysis.Scripting; -using System.Reflection; -using Microsoft.CodeAnalysis.CSharp; -using System.Linq; -using Microsoft.CodeAnalysis; -using System.Runtime.Loader; -using System.Reflection.PortableExecutable; -using System.Reflection.Metadata; -using System.Text.RegularExpressions; -using System.Xml.Linq; - -namespace Barotrauma -{ - class CsScriptLoader : CsScriptBase - { - private List defaultReferences; - - private Dictionary> sources; - public Assembly Assembly { get; private set; } - - public CsScriptLoader() - { - defaultReferences = AppDomain.CurrentDomain.GetAssemblies() - .Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))) - .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) - .ToList(); - - sources = new Dictionary>(); - Assembly = null; - } - - private enum RunType { Standard, Forced, None }; - private bool ShouldRun(ContentPackage cp, string path) - { - if (!Directory.Exists(path + "CSharp")) - { - return false; - } - - var isEnabled = ContentPackageManager.EnabledPackages.All.Contains(cp); - if (File.Exists(path + "CSharp/RunConfig.xml")) - { - Stream stream = File.Open(path + "CSharp/RunConfig.xml", FileMode.Open, FileAccess.Read, FileShare.ReadWrite); - var doc = XDocument.Load(stream); - var elems = doc.Root.Elements().ToArray(); - var elem = elems.FirstOrDefault(e => e.Name.LocalName.Equals(LuaCsSetup.IsServer ? "Server" : (LuaCsSetup.IsClient ? "Client" : "None"), StringComparison.OrdinalIgnoreCase)); - - if (elem != null && Enum.TryParse(elem.Value, true, out RunType rtValue)) - { - if (rtValue == RunType.Standard && isEnabled) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Standard)"); - return true; - } - else if (rtValue == RunType.Forced && (isEnabled || !GameMain.LuaCs.Config.TreatForcedModsAsNormal)) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Forced)"); - return true; - } - else if (rtValue == RunType.None) - { - return false; - } - } - - stream.Close(); - } - - if (isEnabled) - { - LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Assumed)"); - return true; - } - else - { - return false; - } - } - - public void SearchFolders() - { - var packagesAdded = new HashSet(); - var paths = new Dictionary(); - foreach (var cp in ContentPackageManager.AllPackages.Concat(ContentPackageManager.EnabledPackages.All)) - { - if (packagesAdded.Contains(cp)) { continue; } - var path = $"{Path.GetFullPath(Path.GetDirectoryName(cp.Path)).Replace('\\', '/')}/"; - if (ShouldRun(cp, path)) - { - if (paths.ContainsKey(cp.Name)) - { - if (ContentPackageManager.EnabledPackages.All.Contains(cp)) - { - paths[cp.Name] = path; - } - } - else - { - paths.Add(cp.Name, path); - } - packagesAdded.Add(cp); - } - } - - foreach ((var _, var path) in paths) - { - RunFolder(path); - } - } - - public bool HasSources { get => sources.Count > 0; } - - private void AddSources(string folder) - { - foreach (var str in DirSearch(folder)) - { - string s = str.Replace("\\", "/"); - - if (sources.ContainsKey(folder)) - { - sources[folder].Add(s); - } - else - { - sources.Add(folder, new List { s }); - } - } - } - - private void RunFolder(string folder) - { - - AddSources(folder + "/CSharp/Shared"); - -#if SERVER - AddSources(folder + "/CSharp/Server"); -#else - AddSources(folder + "/CSharp/Client"); -#endif - } - - private IEnumerable ParseSources() { - var syntaxTrees = new List(); - - if (sources.Count <= 0) throw new Exception("No Cs sources detected"); - syntaxTrees.Add(AssemblyInfoSyntaxTree(CsScriptAssembly)); - foreach ((var folder, var src) in sources) - { - try - { - foreach (var file in src) - { - var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), ParseOptions, file); - - syntaxTrees.Add(tree); - } - } - catch (Exception ex) - { - LuaCsLogger.LogError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace, LuaCsMessageOrigin.CSharpMod); - } - } - - return syntaxTrees; - } - - private ContentPackage FindSourcePackage(Diagnostic diagnostic) - { - if (diagnostic.Location.SourceTree == null) - { - return null; - } - - string path = diagnostic.Location.SourceTree.FilePath; - foreach (var package in ContentPackageManager.AllPackages) - { - if (Path.GetFullPath(path).StartsWith(Path.GetFullPath(package.Dir))) - { - return package; - } - } - - return null; - } - - public List Compile() - { - IEnumerable syntaxTrees = ParseSources(); - - var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) - .WithMetadataImportOptions(MetadataImportOptions.All) - .WithOptimizationLevel(OptimizationLevel.Release) - .WithAllowUnsafe(true); - - var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic); - topLevelBinderFlagsProperty.SetValue(options, (uint)1 << 22); - - var compilation = CSharpCompilation.Create(CsScriptAssembly, syntaxTrees, defaultReferences, options); - - using (var mem = new MemoryStream()) - { - var result = compilation.Emit(mem); - if (!result.Success) - { - IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); - - string errStr = "CS MODS NOT LOADED | Compilation errors:"; - foreach (Diagnostic diagnostic in failures) - { - errStr += $"\n{diagnostic}"; -#if CLIENT - ContentPackage package = FindSourcePackage(diagnostic); - if (package != null) - { - LuaCsLogger.ShowErrorOverlay($"{package.Name} {package.ModVersion} is causing compilation errors. Check debug console for more details.", 7f, 7f); - } -#endif - } - LuaCsLogger.LogError(errStr, LuaCsMessageOrigin.CSharpMod); - } - else - { - mem.Seek(0, SeekOrigin.Begin); - Assembly = LoadFromStream(mem); - } - } - - if (Assembly != null) - { - RegisterAssemblyWithNativeGame(Assembly); - try - { - return Assembly.GetTypes().Where(t => t.IsSubclassOf(typeof(ACsMod))).ToList(); - } - catch (ReflectionTypeLoadException re) - { - LuaCsLogger.LogError($"Unable to load CsMod Types. {re.Message}", LuaCsMessageOrigin.CSharpMod); - throw re; - } - } - else - { - throw new Exception("Unable to create cs mods assembly."); - } - } - - /// - /// This function should be used whenever a new assembly is created. Wrapper to allow more complicated setup later if need be. - /// - private static void RegisterAssemblyWithNativeGame(Assembly assembly) - { - Barotrauma.ReflectionUtils.AddNonAbstractAssemblyTypes(assembly); - } - - /// - /// This function should be used whenever a new assembly is about to be destroyed/unloaded. Wrapper to allow more complicated setup later if need be. - /// - /// Assembly to remove - private static void UnregisterAssemblyFromNativeGame(Assembly assembly) - { - Barotrauma.ReflectionUtils.RemoveAssemblyFromCache(assembly); - } - - private static string[] DirSearch(string sDir) - { - if (!Directory.Exists(sDir)) - { - return new string[] {}; - } - - return Directory.GetFiles(sDir, "*.cs", SearchOption.AllDirectories); - } - - public void Clear() - { - if (Assembly != null) - { - UnregisterAssemblyFromNativeGame(Assembly); - Assembly = null; - } - } - } -} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs index c0bfcba2bf..6c4c6b7e67 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs @@ -9,7 +9,7 @@ namespace Barotrauma { partial class LuaUserData { - public static Type GetType(string typeName) => LuaCsSetup.GetType(typeName); + public static Type GetType(string typeName) => LuaCsSetup.GetTypeRefCompat(typeName); public static IUserDataDescriptor RegisterType(string typeName) { diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 6c89d46802..290865b395 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -9,8 +9,8 @@ using System.Diagnostics; using MoonSharp.VsCodeDebugger; using System.Reflection; +using System.Runtime.Loader; -[assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsScriptAssembly, AllInternalsVisible = true)] namespace Barotrauma { class LuaCsSetupConfig @@ -66,12 +66,18 @@ partial class LuaCsSetup public LuaCsSteam Steam { get; private set; } public LuaCsPerformanceCounter PerformanceCounter { get; private set; } + // must be available at anytime + private static AssemblyManager _assemblyManager; + public static AssemblyManager AssemblyManager => _assemblyManager ??= new AssemblyManager(); + + private CsPackageManager _pluginPackageManager; + public CsPackageManager PluginPackageManager => _pluginPackageManager ??= new CsPackageManager(AssemblyManager, this); + public LuaCsModStore ModStore { get; private set; } private LuaRequire require { get; set; } - - public CsScriptLoader CsScriptLoader { get; private set; } public LuaCsSetupConfig Config { get; private set; } public MoonSharpVsCodeDebugServer DebugServer { get; private set; } + public bool IsInitialized { get; private set; } private bool ShouldRunCs { @@ -90,7 +96,6 @@ public LuaCsSetup() Game = new LuaGame(); Networking = new LuaCsNetworking(); - DebugServer = new MoonSharpVsCodeDebugServer(); if (File.Exists(configFileName)) @@ -105,8 +110,18 @@ public LuaCsSetup() Config = new LuaCsSetupConfig(); } } - - public static Type GetType(string typeName, bool throwOnError = false, bool ignoreCase = false) + + + /// + /// Tries to get the type by name from all loaded assemblies via Type.GetType(). Intended to be + /// MoonSharp compatible with Generics and by-reference types. + /// NOTE: To get the by-reference type, prefix the type name with "out " or "ref ". + /// + /// The fully-qualified type name. + /// Whether or not to throw an error if no type is found. + /// Whether or not to use a case-sensitive search. + /// The type if found, null if not. + public static Type GetTypeRefCompat(string typeName, bool throwOnError = false, bool ignoreCase = false) { if (typeName == null || typeName.Length == 0) { return null; } @@ -121,12 +136,6 @@ public static Type GetType(string typeName, bool throwOnError = false, bool igno if (type != null) { return byRef ? type.MakeByRefType() : type; } foreach (var a in AppDomain.CurrentDomain.GetAssemblies()) { - if (CsScriptBase.LoadedAssemblyName.Contains(a.GetName().Name)) - { - var attrs = a.GetCustomAttributes(); - var revision = attrs.FirstOrDefault(attr => attr.Key == "Revision")?.Value; - if (revision != null && int.Parse(revision) != (int)CsScriptBase.Revision[a.GetName().Name]) { continue; } - } type = a.GetType(typeName, throwOnError, ignoreCase); if (type != null) { @@ -293,17 +302,21 @@ public void Update() public void Stop() { - foreach (var type in AppDomain.CurrentDomain.GetAssemblies().Where(a => a.GetName().Name == CsScriptBase.CsScriptAssembly).SelectMany(assembly => assembly.GetTypes())) + + // unregister types + foreach (Type type in AssemblyManager.GetAllLoadedACLs().SelectMany( + acl => acl.AssembliesTypes.Select(kvp => kvp.Value))) { UserData.UnregisterType(type, true); } + + PluginPackageManager.UnloadPlugins(); // stop plugin code execution - foreach (var mod in ACsMod.LoadedMods.ToArray()) + if (Lua?.Globals is not null) { - mod.Dispose(); + Lua.Globals.Remove("CsPackageManager"); + Lua.Globals.Remove("AssemblyManager"); } - - ACsMod.LoadedMods.Clear(); if (Thread.CurrentThread == GameMain.MainThread) { @@ -317,27 +330,27 @@ public void Stop() Game?.Stop(); - Hook.Clear(); + Hook?.Clear(); ModStore.Clear(); + LuaScriptLoader = null; + Lua = null; + + // we can only unload assemblies after clearing ModStore/references. + PluginPackageManager.Dispose(); + Game = new LuaGame(); Networking = new LuaCsNetworking(); Timer = new LuaCsTimer(); Steam = new LuaCsSteam(); PerformanceCounter = new LuaCsPerformanceCounter(); - LuaScriptLoader = null; - Lua = null; - if (CsScriptLoader != null) - { - CsScriptLoader.Clear(); - CsScriptLoader.Unload(); - CsScriptLoader = null; - } + IsInitialized = false; } public void Initialize(bool forceEnableCs = false) { - Stop(); + if (IsInitialized) + Stop(); LuaCsLogger.LogMessage("Lua! Version " + AssemblyInfo.GitRevision); @@ -380,6 +393,9 @@ public void Initialize(bool forceEnableCs = false) UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterExtensionType(typeof(MathUtils)); UserData.RegisterExtensionType(typeof(XMLExtensions)); @@ -430,65 +446,80 @@ public void Initialize(bool forceEnableCs = false) DebugConsole.AddWarning("Cs package active! Cs mods are NOT sandboxed, use it at your own risk!"); } - CsScriptLoader = new CsScriptLoader(); - CsScriptLoader.SearchFolders(); - if (CsScriptLoader.HasSources) + Lua.Globals["PluginPackageManager"] = PluginPackageManager; + Lua.Globals["AssemblyManager"] = AssemblyManager; + + try { - try + Stopwatch taskTimer = new(); + taskTimer.Start(); + ModStore.Clear(); + + var state = PluginPackageManager.LoadAssemblyPackages(); + if (state is AssemblyLoadingSuccessState.Success or AssemblyLoadingSuccessState.AlreadyLoaded) { - Stopwatch compilationTime = new Stopwatch(); - compilationTime.Start(); - var modTypes = CsScriptLoader.Compile(); - - modTypes.ForEach(t => - { - t.GetConstructor(new Type[] { })?.Invoke(null); - }); - - compilationTime.Stop(); - LuaCsLogger.LogMessage($"Took {compilationTime.ElapsedMilliseconds}ms to compile and run Cs Scripts."); + if(!PluginPackageManager.PluginsInitialized) + PluginPackageManager.InstantiatePlugins(true); + if(!PluginPackageManager.PluginsPreInit) + PluginPackageManager.RunPluginsPreInit(); // this is intended to be called at startup in the future + if(!PluginPackageManager.PluginsLoaded) + PluginPackageManager.RunPluginsInit(); + state = AssemblyLoadingSuccessState.Success; + taskTimer.Stop(); + ModUtils.Logging.PrintMessage($"{nameof(LuaCsSetup)}: Completed assembly loading. Total time {taskTimer.ElapsedMilliseconds}ms."); + } + else + { + PluginPackageManager.Dispose(); // cleanup if there's an error } - catch (Exception ex) + + if(state is not AssemblyLoadingSuccessState.Success) { - LuaCsLogger.HandleException(ex, LuaCsMessageOrigin.CSharpMod); + ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}: Error while loading Cs-Assembly Mods | Err: {state}"); + taskTimer.Stop(); } } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}::{nameof(Initialize)}() | Error while loading assemblies! Details: {e.Message} | {e.StackTrace}"); + } + IsInitialized = true; } ContentPackage luaPackage = GetPackage(LuaForBarotraumaId); - void runLocal() + void RunLocal() { LuaCsLogger.LogMessage("Using LuaSetup.lua from the Barotrauma Lua/ folder."); string luaPath = LuaSetupFile; CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath))); } - void runWorkshop() + void RunWorkshop() { LuaCsLogger.LogMessage("Using LuaSetup.lua from the content package."); string luaPath = Path.Combine(Path.GetDirectoryName(luaPackage.Path), "Binary/Lua/LuaSetup.lua"); CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath))); } - void runNone() + void RunNone() { LuaCsLogger.LogError("LuaSetup.lua not found! Lua/LuaSetup.lua, no Lua scripts will be executed or work.", LuaCsMessageOrigin.LuaMod); } if (Config.PreferToUseWorkshopLuaSetup) { - if (luaPackage != null) { runWorkshop(); } - else if (File.Exists(LuaSetupFile)) { runLocal(); } - else { runNone(); } + if (luaPackage != null) { RunWorkshop(); } + else if (File.Exists(LuaSetupFile)) { RunLocal(); } + else { RunNone(); } } else { - if (File.Exists(LuaSetupFile)) { runLocal(); } - else if (luaPackage != null) { runWorkshop(); } - else { runNone(); } + if (File.Exists(LuaSetupFile)) { RunLocal(); } + else if (luaPackage != null) { RunWorkshop(); } + else { RunNone(); } } executionNumber++; diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs index 8501004cd2..14a7685ef4 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.IO; using System.Linq; @@ -259,14 +260,22 @@ private enum ValueType private static Type[] LoadDocTypes(XElement typesElem) { var result = new List(); + var loadedTypes = LuaCsSetup.AssemblyManager + .GetAllTypesInLoadedAssemblies() + .ToImmutableHashSet(); + foreach (var elem in typesElem.Elements()) { - var type = Type.GetType(elem.Value); - if (type == null && GameMain.LuaCs?.CsScriptLoader?.Assembly != null) type = GameMain.LuaCs.CsScriptLoader.Assembly.GetType(elem.Value); - if (type == null) throw new Exception($"Type {elem.Value} not found."); - result.Add(type); - + var typesFound = loadedTypes.Where(t => t.FullName?.EndsWith(elem.Value) ?? false).ToImmutableList(); + if (!typesFound.Any()) + { + ModUtils.Logging.PrintError( + $"{nameof(LuaCsConfig)}::{nameof(LoadDocTypes)}() | Unable to find a matching type for {elem.Value}"); + continue; + } + result.AddRange(typesFound); } + return result.ToArray(); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs new file mode 100644 index 0000000000..1ec321ad28 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/ModUtils.cs @@ -0,0 +1,332 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Xml.Serialization; +using Barotrauma; +using Barotrauma.Items.Components; +using Barotrauma.Networking; +using Microsoft.CodeAnalysis; + +namespace Barotrauma; + +public static class ModUtils +{ + #region LOGGING + + public static class Logging + { + public static void PrintMessage(string s) + { +#if SERVER + LuaCsLogger.LogMessage($"[Server] {s}"); +#else + LuaCsLogger.LogMessage($"[Client] {s}"); +#endif + } + + public static void PrintError(string s) + { +#if SERVER + LuaCsLogger.LogError($"[Server] {s}"); +#else + LuaCsLogger.LogError($"[Client] {s}"); +#endif + } + } + + #endregion + + #region FILE_IO + + // ReSharper disable once InconsistentNaming + public static class IO + { + public static IEnumerable FindAllFilesInDirectory(string folder, string pattern, + SearchOption option) + { + try + { + return Directory.GetFiles(folder, pattern, option); + } + catch (DirectoryNotFoundException e) + { + return new string[] { }; + } + } + + public static string PrepareFilePathString(string filePath) => + PrepareFilePathString(Path.GetDirectoryName(filePath)!, Path.GetFileName(filePath)); + + public static string PrepareFilePathString(string path, string fileName) => + Path.Combine(SanitizePath(path), SanitizeFileName(fileName)); + + public static string SanitizeFileName(string fileName) + { + foreach (char c in Barotrauma.IO.Path.GetInvalidFileNameCharsCrossPlatform()) + fileName = fileName.Replace(c, '_'); + return fileName; + } + + /// + /// Gets the sanitized path for the top-level directory for a given content package. + /// + /// The target package. + /// The fully-qualified path, sanitized. + public static string GetContentPackageDir(ContentPackage package) + { + return SanitizePath(Path.GetFullPath(package.Dir)); + } + + public static string SanitizePath(string path) + { + foreach (char c in Path.GetInvalidPathChars()) + path = path.Replace(c.ToString(), "_"); + return path.CleanUpPath(); + } + + public static IOActionResultState GetOrCreateFileText(string filePath, out string fileText, Func fileDataFactory = null, bool createFile = true) + { + fileText = null; + string fp = Path.GetFullPath(SanitizePath(filePath)); + + IOActionResultState ioActionResultState = IOActionResultState.Success; + if (createFile) + { + ioActionResultState = CreateFilePath(SanitizePath(filePath), out fp, fileDataFactory); + } + else if (!File.Exists(fp)) + { + return IOActionResultState.FileNotFound; + } + + if (ioActionResultState == IOActionResultState.Success) + { + try + { + fileText = File.ReadAllText(fp!); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + return ioActionResultState; + } + + public static IOActionResultState CreateFilePath(string filePath, out string formattedFilePath, Func fileDataFactory = null) + { + string file = Path.GetFileName(filePath); + string path = Path.GetDirectoryName(filePath)!; + + formattedFilePath = IO.PrepareFilePathString(path, file); + try + { + if (!Directory.Exists(path)) + Directory.CreateDirectory(path); + if (!File.Exists(formattedFilePath)) + File.WriteAllText(formattedFilePath, fileDataFactory is null ? "" : fileDataFactory.Invoke()); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {formattedFilePath ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {formattedFilePath ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {path ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {formattedFilePath ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {formattedFilePath ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {formattedFilePath ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {path ?? "null"} | Exception Details: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + public static IOActionResultState WriteFileText(string filePath, string fileText) + { + IOActionResultState ioActionResultState = CreateFilePath(filePath, out var fp); + if (ioActionResultState == IOActionResultState.Success) + { + try + { + File.WriteAllText(fp!, fileText); + return IOActionResultState.Success; + } + catch (ArgumentNullException ane) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}"); + return IOActionResultState.FilePathNull; + } + catch (ArgumentException ae) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}"); + return IOActionResultState.FilePathInvalid; + } + catch (DirectoryNotFoundException dnfe) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}"); + return IOActionResultState.DirectoryMissing; + } + catch (PathTooLongException ptle) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}"); + return IOActionResultState.PathTooLong; + } + catch (NotSupportedException nse) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}"); + return IOActionResultState.InvalidOperation; + } + catch (IOException ioe) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}"); + return IOActionResultState.IOFailure; + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}"); + return IOActionResultState.UnknownError; + } + } + + return ioActionResultState; + } + + /// + /// Uses the XmlSerializer to try and load the data for a given type from the file. Optionally, will create a new + /// instance of the type from the supplied factory method and save the data to the file. + /// + /// The instance of the type specified. + /// The full path to the file. + /// Factory method to produce a default version of the type. + /// Whether or not to try and create a new file and save the data to disk on file not found. + /// The type to be loaded or created. + /// Operation success. + public static bool LoadOrCreateTypeXml(out T instance, + string filepath, Func typeFactory = null, bool createFile = true) where T : class, new() + { + instance = null; + filepath = filepath.CleanUpPath(); + if (IOActionResultState.Success == GetOrCreateFileText( + filepath, out string fileText, typeFactory is not null ? () => + { + using StringWriter sw = new StringWriter(); + T t = typeFactory?.Invoke(); + if (t is not null) + { + XmlSerializer s = new XmlSerializer(typeof(T)); + s.Serialize(sw, t); + return sw.ToString(); + } + return ""; + } : null, createFile)) + { + XmlSerializer s = new XmlSerializer(typeof(T)); + try + { + using TextReader tr = new StringReader(fileText); + instance = (T)s.Deserialize(tr); + return true; + } + catch(InvalidOperationException ioe) + { + ModUtils.Logging.PrintError($"Error while parsing type data for {typeof(T)}."); + #if DEBUG + ModUtils.Logging.PrintError($"Exception: {ioe.Message}. Details: {ioe.InnerException?.Message}"); + #endif + instance = null; + return false; + } + } + + return false; + } + + public enum IOActionResultState + { + Success, FileNotFound, FilePathNull, FilePathInvalid, DirectoryMissing, PathTooLong, InvalidOperation, IOFailure, UnknownError + } + } + + #endregion + + #region GAME + + public static class Game + { + /// + /// Returns whether or not there is a round running. + /// + /// + public static bool IsRoundInProgress() + { +#if CLIENT + if (Screen.Selected is not null + && Screen.Selected.IsEditor) + return false; +#endif + return GameMain.GameSession is not null && Level.Loaded is not null; + } + + } + + #endregion +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs similarity index 60% rename from Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs rename to Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs index 60ef6551d5..76dfac73f2 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/ACsMod.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ACsMod.cs @@ -1,11 +1,11 @@ using System; using System.Collections.Generic; using System.IO; -using System.Reflection; namespace Barotrauma { - public abstract class ACsMod : IDisposable + [Obsolete("Make your class implement IAssemblyPlugin instead.")] + public abstract class ACsMod : IAssemblyPlugin { private static List mods = new List(); public static List LoadedMods { get => mods; } @@ -18,7 +18,6 @@ public static string GetStoreFolder() where T : ACsMod if (!Directory.Exists(modFolder)) Directory.CreateDirectory(modFolder); return modFolder; } - public static string GetSoreFolder() where T : ACsMod => GetStoreFolder(); public bool IsDisposed { get; private set; } @@ -29,7 +28,23 @@ public ACsMod() LoadedMods.Add(this); } - public void Dispose() + /// + /// Called as soon as plugin loading begins, use this for internal setup only. + /// + public virtual void Initialize() { } + + /// + /// Called once all plugins have completed Initialization. Put cross-mod code here. + /// + public virtual void OnLoadCompleted() { } + + /// + /// [NotImplemented] Called before vanilla content is loaded. Use to patch Barotrauma classes before they're + /// instantiated. + /// + public void PreInitPatching() { } + + public virtual void Dispose() { try { @@ -43,8 +58,7 @@ public void Dispose() LoadedMods.Remove(this); IsDisposed = true; } - - /// Error or client exit + public abstract void Stop(); } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs new file mode 100644 index 0000000000..6e60184bb7 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/ApplicationMode.cs @@ -0,0 +1,6 @@ +namespace Barotrauma; + +public enum ApplicationMode +{ + Client, Server +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs new file mode 100644 index 0000000000..e55821eb3d --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyLoadingSuccessState.cs @@ -0,0 +1,15 @@ +namespace Barotrauma; + +public enum AssemblyLoadingSuccessState +{ + ACLLoadFailure, + AlreadyLoaded, + BadFilePath, + CannotLoadFile, + InvalidAssembly, + NoAssemblyFound, + PluginInstanceFailure, + BadName, + CannotLoadFromStream, + Success +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs new file mode 100644 index 0000000000..e2d5ad2d0f --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/AssemblyManager.cs @@ -0,0 +1,770 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.Loader; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +// ReSharper disable EventNeverSubscribedTo.Global +// ReSharper disable InconsistentNaming + +namespace Barotrauma; + +/*** + * Note: This class was written to be thread-safe in order to allow parallelization in loading in the future if the need + * becomes necessary as there is almost no serial performance overhead for adding threading protection. + */ + +/// +/// Provides functionality for the loading, unloading and management of plugins implementing IAssemblyPlugin. +/// All plugins are loaded into their own AssemblyLoadContext along with their dependencies. +/// +public partial class AssemblyManager +{ + #region ExternalAPI + + /// + /// Called when an assembly is loaded. + /// + public event Action OnAssemblyLoaded; + + /// + /// Called when an assembly is marked for unloading, before unloading begins. You should use this to cleanup + /// any references that you have to this assembly. + /// + public event Action OnAssemblyUnloading; + + /// + /// Called whenever an exception is thrown. First arg is a formatted message, Second arg is the Exception. + /// + public event Action OnException; + + /// + /// For unloading issue debugging. Called whenever MemoryFileAssemblyContextLoader [load context] is unloaded. + /// + public event Action OnACLUnload; + + #if DEBUG + + /// + /// [DEBUG ONLY] + /// Returns a list of the current unloading ACLs. + /// + public ImmutableList> StillUnloadingACLs + { + get + { + OpsLockUnloaded.EnterReadLock(); + try + { + return UnloadingACLs.ToImmutableList(); + } + finally + { + OpsLockUnloaded.ExitReadLock(); + } + } + } + + #endif + + + // ReSharper disable once MemberCanBePrivate.Global + /// + /// Checks if there are any AssemblyLoadContexts still in the process of unloading. + /// + public bool IsCurrentlyUnloading + { + get + { + OpsLockUnloaded.EnterReadLock(); + try + { + return UnloadingACLs.Any(); + } + catch (Exception) + { + return false; + } + finally + { + OpsLockUnloaded.ExitReadLock(); + } + } + } + + // Old API compatibility + public IEnumerable GetSubTypesInLoadedAssemblies() + { + return GetSubTypesInLoadedAssemblies(false); + } + + + /// + /// Allows iteration over all non-interface types in all loaded assemblies in the AsmMgr that are assignable to the given type (IsAssignableFrom). + /// Warning: care should be used when using this method in hot paths as performance may be affected. + /// + /// The type to compare against + /// Forces caches to clear and for the lists of types to be rebuilt. + /// An Enumerator for matching types. + public IEnumerable GetSubTypesInLoadedAssemblies(bool rebuildList) + { + Type targetType = typeof(T); + string typeName = targetType.FullName ?? targetType.Name; + + // rebuild + if (rebuildList) + RebuildTypesList(); + + // check cache + if (_subTypesLookupCache.TryGetValue(typeName, out var subTypeList)) + { + return subTypeList; + } + + // build from scratch + OpsLockLoaded.EnterReadLock(); + try + { + // build list + var list1 = _defaultContextTypes + .Where(kvp1 => targetType.IsAssignableFrom(kvp1.Value) && !kvp1.Value.IsInterface) + .Concat(LoadedACLs + .SelectMany(kvp => kvp.Value.AssembliesTypes) + .Where(kvp2 => targetType.IsAssignableFrom(kvp2.Value) && !kvp2.Value.IsInterface)) + .Select(kvp3 => kvp3.Value) + .ToImmutableList(); + + // only add if we find something + if (list1.Count > 0) + { + if (!_subTypesLookupCache.TryAdd(typeName, list1)) + { + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to add subtypes to cache of type {typeName}!"); + } + } + else + { + ModUtils.Logging.PrintMessage($"{nameof(AssemblyManager)}: Warning: No types found during search for subtypes of {typeName}"); + } + + return list1; + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + /// + /// Tries to get types assignable to type from the ACL given the Guid. + /// + /// + /// + /// + /// + public bool TryGetSubTypesFromACL(Guid id, out IEnumerable types) + { + Type targetType = typeof(T); + + if (TryGetACL(id, out var acl)) + { + types = acl.AssembliesTypes + .Where(kvp => targetType.IsAssignableFrom(kvp.Value) && !kvp.Value.IsInterface) + .Select(kvp => kvp.Value); + return true; + } + + types = null; + return false; + } + + /// + /// Tries to get types from the ACL given the Guid. + /// + /// + /// + /// + /// + public bool TryGetSubTypesFromACL(Guid id, out IEnumerable types) + { + if (TryGetACL(id, out var acl)) + { + types = acl.AssembliesTypes.Select(kvp => kvp.Value); + return true; + } + + types = null; + return false; + } + + + /// + /// Allows iteration over all types, including interfaces, in all loaded assemblies in the AsmMgr who's names match the string. + /// + /// The string name of the type to search for. + /// An Enumerator for matching types. + public IEnumerable GetTypesByName(string typeName, bool rebuildOnFail = false) + { + List types = new(); + + TypesListHelper(); + if (types.Count > 0) + return types; + + if (rebuildOnFail) + { + // we couldn't find it, rebuild and try one more time + RebuildTypesList(); + TypesListHelper(); + } + return types; + + void TypesListHelper() + { + if (_defaultContextTypes.TryGetValue(typeName, out var type1)) + { + if (type1 is not null) + types.Add(type1); + } + + OpsLockLoaded.EnterReadLock(); + try + { + foreach (KeyValuePair loadedAcl in LoadedACLs) + { + var at = loadedAcl.Value.AssembliesTypes; + if (at.TryGetValue(typeName, out var type2)) + { + if (type2 is not null) + types.Add(type2); + } + } + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + } + + /// + /// Allows iteration over all types (including interfaces) in all loaded assemblies managed by the AsmMgr. + /// Warning: High usage may result in performance issues. + /// + /// An Enumerator for iteration. + public IEnumerable GetAllTypesInLoadedAssemblies() + { + OpsLockLoaded.EnterReadLock(); + try + { + return AssemblyLoadContext.Default.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .Concat(LoadedACLs + .SelectMany(kvp => kvp.Value.AssembliesTypes.Select(kv => kv.Value))) + .ToImmutableList(); + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + /// + /// Returns a list of all loaded ACLs. + /// WARNING: References to these ACLs outside of the AssemblyManager should be kept in a WeakReference in order + /// to avoid causing issues with unloading/disposal. + /// + /// + public IEnumerable GetAllLoadedACLs() + { + try + { + OpsLockLoaded.EnterReadLock(); + return LoadedACLs.Select(kvp => kvp.Value).ToImmutableList(); + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + + } + + #endregion + + #region InternalAPI + + /// + /// Used by content package and plugin management to stop unloading of a given ACL until all plugins have gracefully closed. + /// + public event System.Func IsReadyToUnloadACL; + + public AssemblyLoadingSuccessState LoadAssemblyFromMemory([NotNull] string compiledAssemblyName, + [NotNull] IEnumerable syntaxTree, + IEnumerable externalMetadataReferences, + [NotNull] CSharpCompilationOptions compilationOptions, + ref Guid id, + IEnumerable externFileAssemblyRefs = null) + { + // validation + if (compiledAssemblyName.IsNullOrWhiteSpace()) + return AssemblyLoadingSuccessState.BadName; + + if (!GetOrCreateACL(id, out var acl)) + return AssemblyLoadingSuccessState.ACLLoadFailure; + + id = acl.Id; // pass on true id returned + + // this acl is already hosting an in-memory assembly + if (acl.Acl.CompiledAssembly is not null) + return AssemblyLoadingSuccessState.AlreadyLoaded; + + // compile + var state = acl.Acl.CompileAndLoadScriptAssembly(compiledAssemblyName, syntaxTree, externalMetadataReferences, + compilationOptions, out var messages, externFileAssemblyRefs); + + // get types + if (state is AssemblyLoadingSuccessState.Success) + { + _subTypesLookupCache.Clear(); + acl.RebuildTypesList(); + OnAssemblyLoaded?.Invoke(acl.Acl.CompiledAssembly); + } + else + { + ModUtils.Logging.PrintError($"Unable to compile assembly '{compiledAssemblyName}' due to errors: {messages}"); + } + + return state; + } + + /// + /// Switches the ACL with the given Guid to Template Mode, which disables assembly name resolution for any assemblies loaded in it. + /// These ACLs are intended to be used to host Assemblies for information only and not for code execution. + /// WARNING: This process is irreversible. + /// + /// Guid of the ACL. + /// Whether or not an ACL was found with the given ID. + public bool SetACLToTemplateMode(Guid guid) + { + if (!TryGetACL(guid, out var acl)) + return false; + acl.Acl.IsTemplateMode = true; + return true; + } + + /// + /// Tries to load all assemblies at the supplied file paths list into the ACl with the given Guid. + /// If the supplied Guid is Empty, then a new ACl will be created and the Guid will be assigned to it. + /// + /// List of assemblies to try and load. + /// Guid of the ACL or Empty if none specified. Guid of ACL will be assigned to this var. + /// Operation success messages. + /// + public AssemblyLoadingSuccessState LoadAssembliesFromLocations([NotNull] IEnumerable filePaths, + ref Guid id) + { + + if (filePaths is null) + { + throw new ArgumentNullException( + $"{nameof(AssemblyManager)}::{nameof(LoadAssembliesFromLocations)}() | file paths supplied is null!"); + } + + ImmutableList assemblyFilePaths = filePaths.ToImmutableList(); // copy the list before loading + + if (!assemblyFilePaths.Any()) + { + return AssemblyLoadingSuccessState.NoAssemblyFound; + } + + if (GetOrCreateACL(id, out var loadedAcl)) + { + var state = loadedAcl.Acl.LoadFromFiles(assemblyFilePaths); + // if failure, we dispose of the acl + if (state != AssemblyLoadingSuccessState.Success) + { + DisposeACL(loadedAcl.Id); + ModUtils.Logging.PrintError($"ACL failed, unloading..."); + return state; + } + // build types list + _subTypesLookupCache.Clear(); + loadedAcl.RebuildTypesList(); + id = loadedAcl.Id; + foreach (Assembly assembly in loadedAcl.Acl.Assemblies) + { + OnAssemblyLoaded?.Invoke(assembly); + } + return state; + } + + return AssemblyLoadingSuccessState.ACLLoadFailure; + } + + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Synchronized)] + public bool TryBeginDispose() + { + OpsLockLoaded.EnterWriteLock(); + OpsLockUnloaded.EnterWriteLock(); + try + { + _subTypesLookupCache.Clear(); + + foreach (KeyValuePair loadedAcl in LoadedACLs) + { + if (loadedAcl.Value.Acl is not null) + { + if (IsReadyToUnloadACL is not null) + { + foreach (Delegate del in IsReadyToUnloadACL.GetInvocationList()) + { + if (del is System.Func { } func) + { + if (!func.Invoke(loadedAcl.Value)) + return false; // Not ready, exit + } + } + } + + foreach (Assembly assembly in loadedAcl.Value.Acl.Assemblies) + { + OnAssemblyUnloading?.Invoke(assembly); + } + + UnloadingACLs.Add(new WeakReference(loadedAcl.Value.Acl, true)); + loadedAcl.Value.ClearTypesList(); + loadedAcl.Value.Acl.Unload(); + OnACLUnload?.Invoke(loadedAcl.Value.Id); + } + } + + LoadedACLs.Clear(); + return true; + } + catch + { + // should never happen + return false; + } + finally + { + OpsLockUnloaded.ExitWriteLock(); + OpsLockLoaded.ExitWriteLock(); + } + } + + + [MethodImpl(MethodImplOptions.NoInlining)] + public bool FinalizeDispose() + { + bool isUnloaded; + OpsLockUnloaded.EnterUpgradeableReadLock(); + try + { + List> toRemove = new(); + foreach (WeakReference weakReference in UnloadingACLs) + { + if (!weakReference.TryGetTarget(out _)) + { + toRemove.Add(weakReference); + } + } + + if (toRemove.Any()) + { + OpsLockUnloaded.EnterWriteLock(); + try + { + foreach (WeakReference reference in toRemove) + { + UnloadingACLs.Remove(reference); + } + } + finally + { + OpsLockUnloaded.ExitWriteLock(); + } + } + isUnloaded = !UnloadingACLs.Any(); + } + finally + { + OpsLockUnloaded.ExitUpgradeableReadLock(); + } + + return isUnloaded; + } + + /// + /// Tries to retrieve the LoadedACL with the given ID or null if none is found. + /// WARNING: External references to this ACL with long lifespans should be kept in a WeakReference + /// to avoid causing unloading/disposal issues. + /// + /// GUID of the ACL. + /// The found ACL or null if none was found. + /// Whether or not an ACL was found. + [MethodImpl(MethodImplOptions.NoInlining)] + public bool TryGetACL(Guid id, out LoadedACL acl) + { + acl = null; + OpsLockLoaded.EnterReadLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id)) + return false; + acl = LoadedACLs[id]; + return true; + } + finally + { + OpsLockLoaded.ExitReadLock(); + } + } + + + /// + /// Gets or creates an AssemblyCtxLoader for the given ID. Creates if the ID is empty or no ACL can be found. + /// [IMPORTANT] After calling this method, the id you use should be taken from the acl container (acl.Id). + /// + /// + /// + /// Should only return false if an error occurs. + [MethodImpl(MethodImplOptions.NoInlining)] + private bool GetOrCreateACL(Guid id, out LoadedACL acl) + { + OpsLockLoaded.EnterUpgradeableReadLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null) + { + OpsLockLoaded.EnterWriteLock(); + try + { + id = Guid.NewGuid(); + acl = new LoadedACL(id, this); + LoadedACLs[id] = acl; + return true; + } + finally + { + OpsLockLoaded.ExitWriteLock(); + } + } + else + { + acl = LoadedACLs[id]; + return true; + } + + } + catch + { + // should never happen but in-case + acl = null; + return false; + } + finally + { + OpsLockLoaded.ExitUpgradeableReadLock(); + } + } + + + [MethodImpl(MethodImplOptions.NoInlining)] + private bool DisposeACL(Guid id) + { + OpsLockLoaded.EnterWriteLock(); + OpsLockUnloaded.EnterWriteLock(); + try + { + if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null) + { + return false; // nothing to dispose of + } + + var acl = LoadedACLs[id]; + + foreach (Assembly assembly in acl.Acl.Assemblies) + { + OnAssemblyUnloading?.Invoke(assembly); + } + + _subTypesLookupCache.Clear(); + UnloadingACLs.Add(new WeakReference(acl.Acl, true)); + acl.Acl.Unload(); + OnACLUnload?.Invoke(acl.Id); + + return true; + } + catch + { + // should never happen + return false; + } + finally + { + OpsLockLoaded.ExitWriteLock(); + OpsLockUnloaded.ExitWriteLock(); + } + } + + internal AssemblyManager() + { + RebuildTypesList(); + } + + /// + /// Rebuilds the list of types in the default assembly load context. + /// + private void RebuildTypesList() + { + try + { + _defaultContextTypes = AssemblyLoadContext.Default.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .ToImmutableDictionary(t => t.FullName ?? t.Name, t => t); + _subTypesLookupCache.Clear(); + } + catch(ArgumentException _) + { + try + { + // some types must've had duplicate type names, build the list while filtering + Dictionary types = new(); + foreach (var type in AssemblyLoadContext.Default.Assemblies.SelectMany(a => a.GetSafeTypes())) + { + try + { + types.TryAdd(type.FullName ?? type.Name, type); + } + catch + { + // ignore, null key exception + } + } + + _defaultContextTypes = types.ToImmutableDictionary(); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to create list of default assembly types! Default AssemblyLoadContext types searching not available."); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Exception Details :{e.Message} | {e.InnerException}"); +#endif + _defaultContextTypes = ImmutableDictionary.Empty; + } + } + } + + #endregion + + #region Data + + private readonly ConcurrentDictionary> _subTypesLookupCache = new(); + private ImmutableDictionary _defaultContextTypes; + private readonly ConcurrentDictionary LoadedACLs = new(); + private readonly List> UnloadingACLs= new(); + private readonly ReaderWriterLockSlim OpsLockLoaded = new ReaderWriterLockSlim(); + private readonly ReaderWriterLockSlim OpsLockUnloaded = new ReaderWriterLockSlim(); + + #endregion + + #region TypeDefs + + + public sealed class LoadedACL + { + public readonly Guid Id; + private ImmutableDictionary _assembliesTypes = ImmutableDictionary.Empty; + public readonly MemoryFileAssemblyContextLoader Acl; + private readonly AssemblyManager _manager; + + internal LoadedACL(Guid id, AssemblyManager manager) + { + this.Id = id; + this.Acl = new(manager); + this._manager = manager; + } + public ImmutableDictionary AssembliesTypes => _assembliesTypes; + + /// + /// Rebuild the list of types from assemblies loaded in the AsmCtxLoader. + /// + internal void RebuildTypesList() + { + ClearTypesList(); + try + { + _assembliesTypes = this.Acl.Assemblies + .SelectMany(a => a.GetSafeTypes()) + .ToImmutableDictionary(t => t.FullName ?? t.Name, t => t); + } + catch(ArgumentException _) + { + // some types must've had duplicate type names, build the list while filtering + Dictionary types = new(); + foreach (var type in this.Acl.Assemblies.SelectMany(a => a.GetSafeTypes())) + { + try + { + types.TryAdd(type.FullName ?? type.Name, type); + } + catch + { + // ignore, null key exception + } + } + + _assembliesTypes = types.ToImmutableDictionary(); + } + } + + internal void ClearTypesList() + { + _assembliesTypes.Clear(); + } + } + + #endregion +} + +public static class AssemblyExtensions +{ + /// + /// Gets all types in the given assembly. Handles invalid type scenarios. + /// + /// The assembly to scan + /// An enumerable collection of types. + public static IEnumerable GetSafeTypes(this Assembly assembly) + { + // Based on https://github.com/Qkrisi/ktanemodkit/blob/master/Assets/Scripts/ReflectionHelper.cs#L53-L67 + + try + { + return assembly.GetTypes(); + } + catch (ReflectionTypeLoadException re) + { + try + { + return re.Types.Where(x => x != null)!; + } + catch (InvalidOperationException ioe) + { + return new List(); + } + } + catch (Exception e) + { + return new List(); + } + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs new file mode 100644 index 0000000000..188536a672 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/CsPackageManager.cs @@ -0,0 +1,976 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading; +using Barotrauma.Steam; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using MonoMod.Utils; + +namespace Barotrauma; + +public sealed class CsPackageManager : IDisposable +{ + #region PRIVATE_FUNCDATA + + private static readonly CSharpParseOptions ScriptParseOptions = CSharpParseOptions.Default + .WithPreprocessorSymbols(new[] + { +#if SERVER + "SERVER" +#elif CLIENT + "CLIENT" +#else + "UNDEFINED" +#endif +#if DEBUG + ,"DEBUG" +#endif + }); + +#if WINDOWS + private const string PLATFORM_TARGET = "Windows"; +#elif OSX + private const string PLATFORM_TARGET = "OSX"; +#elif LINUX + private const string PLATFORM_TARGET = "Linux"; +#endif + +#if CLIENT + private const string ARCHITECTURE_TARGET = "Client"; +#elif SERVER + private const string ARCHITECTURE_TARGET = "Server"; +#endif + + private static readonly CSharpCompilationOptions CompilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + .WithMetadataImportOptions(MetadataImportOptions.All) +#if DEBUG + .WithOptimizationLevel(OptimizationLevel.Debug) +#else + .WithOptimizationLevel(OptimizationLevel.Release) +#endif + .WithAllowUnsafe(true); + + private static readonly SyntaxTree BaseAssemblyImports = CSharpSyntaxTree.ParseText( + new StringBuilder() + .AppendLine("using System.Reflection;") + .AppendLine("using Barotrauma;") + .AppendLine("using System.Runtime.CompilerServices;") +#if CLIENT + .AppendLine("[assembly: IgnoresAccessChecksTo(\"Barotrauma\")]") +#elif SERVER + .AppendLine("[assembly: IgnoresAccessChecksTo(\"DedicatedServer\")]") +#endif + .ToString(), + ScriptParseOptions); + + private const string SCRIPT_FILE_REGEX = "*.cs"; + private const string ASSEMBLY_FILE_REGEX = "*.dll"; + + private readonly float _assemblyUnloadTimeoutSeconds = 4f; + private Guid _publicizedAssemblyLoader; + private readonly List _currentPackagesByLoadOrder = new(); + private readonly Dictionary> _packagesDependencies = new(); + private readonly Dictionary _loadedCompiledPackageAssemblies = new(); + private readonly Dictionary _reverseLookupGuidList = new(); + private readonly Dictionary> _loadedPlugins = new (); + private readonly Dictionary> _pluginTypes = new(); // where Type : IAssemblyPlugin + private readonly Dictionary _packageRunConfigs = new(); + private readonly Dictionary> _luaRegisteredTypes = new(); + private readonly AssemblyManager _assemblyManager; + private readonly LuaCsSetup _luaCsSetup; + private DateTime _assemblyUnloadStartTime; + + + #endregion + + #region PUBLIC_API + + #region LUA_EXTENSIONS + + /// + /// Searches for all types in all loaded assemblies from content packages who's names contain the name string and registers them with the Lua Interpreter. + /// + /// + /// + /// + public bool LuaTryRegisterPackageTypes(string name, bool caseSensitive = false) + { + if (!AssembliesLoaded) + return false; + var matchingPacks = _loadedCompiledPackageAssemblies + .Where(kvp => kvp.Key.Name.ToLowerInvariant().Contains(name.ToLowerInvariant())) + .Select(kvp => kvp.Value) + .ToImmutableList(); + if (!matchingPacks.Any()) + return false; + var types = matchingPacks + .Where(guid => !_luaRegisteredTypes.ContainsKey(guid)) + .Select(guid => new KeyValuePair>( + guid, + _assemblyManager.TryGetSubTypesFromACL(guid, out var types) + ? types.ToImmutableList() + : ImmutableList.Empty)) + .ToImmutableList(); + if (!types.Any()) + return false; + foreach (var kvp in types) + { + _luaRegisteredTypes[kvp.Key] = kvp.Value; + foreach (Type type in kvp.Value) + { + MoonSharp.Interpreter.UserData.RegisterType(type); + } + } + + return true; + } + + #endregion + + /// + /// Whether or not assemblies have been loaded. + /// + public bool AssembliesLoaded { get; private set; } + + + /// + /// Whether or not loaded plugins had their preloader run. + /// + public bool PluginsPreInit { get; private set; } + + /// + /// Whether or not plugins' types have been instantiated. + /// + public bool PluginsInitialized { get; private set; } = false; + + /// + /// Whether or not plugins are fully loaded. + /// + public bool PluginsLoaded { get; private set; } = false; + + public IEnumerable GetCurrentPackagesByLoadOrder() => _currentPackagesByLoadOrder; + + /// + /// Tries to find the content package that a given plugin belongs to. + /// + /// Package if found, null otherwise. + /// The IAssemblyPlugin type to find. + /// + public bool TryGetPackageForPlugin(out ContentPackage package) where T : IAssemblyPlugin + { + package = null; + + var t = typeof(T); + var guid = _pluginTypes + .Where(kvp => kvp.Value.Contains(t)) + .Select(kvp => kvp.Key) + .FirstOrDefault(Guid.Empty); + + if (guid.Equals(Guid.Empty) || !_reverseLookupGuidList.ContainsKey(guid) || _reverseLookupGuidList[guid] is null) + return false; + package = _reverseLookupGuidList[guid]; + return true; + } + + + /// + /// Tries to get the loaded plugins for a given package. + /// + /// Package to find. + /// The collection of loaded plugins. + /// + public bool TryGetLoadedPluginsForPackage(ContentPackage package, out IEnumerable loadedPlugins) + { + loadedPlugins = null; + if (package is null || !_loadedCompiledPackageAssemblies.ContainsKey(package)) + return false; + var guid = _loadedCompiledPackageAssemblies[package]; + if (guid.Equals(Guid.Empty) || !_loadedPlugins.ContainsKey(guid)) + return false; + loadedPlugins = _loadedPlugins[guid]; + return true; + } + + /// + /// Called when clean up is being performed. Use when relying on or making use of references from this manager. + /// + public event Action OnDispose; + + public void Dispose() + { + // send events for cleanup + OnDispose?.Invoke(); + // cleanup events + if (OnDispose is not null) + { + foreach (Delegate del in OnDispose.GetInvocationList()) + { + OnDispose -= (del as System.Action); + } + } + + // cleanup plugins and assemblies + ReflectionUtils.ResetCache(); + UnloadPlugins(); + + // try cleaning up the assemblies + _pluginTypes.Clear(); // remove assembly references + _loadedPlugins.Clear(); + + // lua cleanup + foreach (var kvp in _luaRegisteredTypes) + { + foreach (Type type in kvp.Value) + { + MoonSharp.Interpreter.UserData.UnregisterType(type); + } + } + _luaRegisteredTypes.Clear(); + + _assemblyUnloadStartTime = DateTime.Now; + _publicizedAssemblyLoader = Guid.Empty; + + // we can't wait forever or app dies but we can try to be graceful + while (!_assemblyManager.TryBeginDispose()) + { + if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now) + { + break; + } + } + + _assemblyUnloadStartTime = DateTime.Now; + while (!_assemblyManager.FinalizeDispose()) + { + if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now) + { + break; + } + } + + _assemblyManager.OnAssemblyLoaded -= AssemblyManagerOnAssemblyLoaded; + _assemblyManager.OnAssemblyUnloading -= AssemblyManagerOnAssemblyUnloading; + + _publicizedAssemblyLoader = Guid.Empty; + + // clear lists after cleaning up + _packagesDependencies.Clear(); + _loadedCompiledPackageAssemblies.Clear(); + _reverseLookupGuidList.Clear(); + _packageRunConfigs.Clear(); + _currentPackagesByLoadOrder.Clear(); + + AssembliesLoaded = false; + GC.SuppressFinalize(this); + } + + /// + /// Begins the loading process of scanning packages for scripts and binary assemblies, compiling and executing them. + /// + /// + public AssemblyLoadingSuccessState LoadAssemblyPackages() + { + if (AssembliesLoaded) + { + return AssemblyLoadingSuccessState.AlreadyLoaded; + } + + _assemblyManager.OnAssemblyLoaded += AssemblyManagerOnAssemblyLoaded; + _assemblyManager.OnAssemblyUnloading += AssemblyManagerOnAssemblyUnloading; + + // load publicized assemblies + var publicizedDir = Path.Combine(Environment.CurrentDirectory, "Publicized"); + ImmutableList publicizedAssemblies = ImmutableList.Empty; + if (Directory.Exists(publicizedDir)) + { + // search for assemblies + var list = Directory.GetFiles(publicizedDir, "*.dll") +#if CLIENT + .Where(s => !s.ToLowerInvariant().EndsWith("dedicatedserver.dll")); +#elif SERVER + .Where(s => !s.ToLowerInvariant().EndsWith("barotrauma.dll")); +#endif + + // try load them into an acl + var loadState = _assemblyManager.LoadAssembliesFromLocations(list, ref _publicizedAssemblyLoader); + + // loaded + if (loadState is AssemblyLoadingSuccessState.Success) + { + if (_assemblyManager.TryGetACL(_publicizedAssemblyLoader, out var acl)) + { + publicizedAssemblies = acl.Acl.Assemblies.ToImmutableList(); + _assemblyManager.SetACLToTemplateMode(_publicizedAssemblyLoader); + } + } + } + + + // get packages + IEnumerable packages = BuildPackagesList(); + + // check and load config + _packageRunConfigs.AddRange(packages + .Select(p => new KeyValuePair(p, GetRunConfigForPackage(p))) + .ToDictionary(p => p.Key, p=> p.Value)); + + // filter not to be loaded + var cpToRun = _packageRunConfigs + .Where(kvp => ShouldRunPackage(kvp.Key, kvp.Value)) + .Select(kvp => kvp.Key) + .ToImmutableList(); + + // build dependencies map + bool reliableMap = TryBuildDependenciesMap(cpToRun, out var packDeps); + if (!reliableMap) + { + ModUtils.Logging.PrintMessage($"{nameof(CsPackageManager)}: Unable to create reliable dependencies map."); + } + + _packagesDependencies.AddRange(packDeps.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value.ToImmutableList()) + ); + + List packagesToLoadInOrder = new(); + + // build load order + if (reliableMap && OrderAndFilterPackagesByDependencies( + _packagesDependencies, + out var readyToLoad, + out var cannotLoadPackages, + null)) + { + packagesToLoadInOrder.AddRange(readyToLoad); + if (cannotLoadPackages is not null) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the following mods due to dependency errors:"); + foreach (var pair in cannotLoadPackages) + { + ModUtils.Logging.PrintError($"Package: {pair.Key.Name} | Reason: {pair.Value}"); + } + } + } + else + { + // use unsorted list on failure and send error message. + packagesToLoadInOrder.AddRange(_packagesDependencies.Select( p=> p.Key)); + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to create a reliable load order. Defaulting to unordered loading!"); + } + + // get assemblies and scripts' filepaths from packages + var toLoad = packagesToLoadInOrder + .Select(cp => new KeyValuePair( + cp, + new LoadableData( + TryScanPackagesForAssemblies(cp, out var list1) ? list1 : null, + TryScanPackageForScripts(cp, out var list2) ? list2 : null))) + .ToImmutableDictionary(); + + HashSet badPackages = new(); + foreach (var pair in toLoad) + { + // check if unloadable + if (badPackages.Contains(pair.Key)) + continue; + + // try load binary assemblies + var id = Guid.Empty; // id for the ACL for this package defined by AssemblyManager. + AssemblyLoadingSuccessState successState; + if (pair.Value.AssembliesFilePaths is not null && pair.Value.AssembliesFilePaths.Any()) + { + ModUtils.Logging.PrintMessage($"Loading assemblies for CPackage {pair.Key.Name}"); + foreach (string assembliesFilePath in pair.Value.AssembliesFilePaths) + { + ModUtils.Logging.PrintMessage($"Found assemblies located at {Path.GetFullPath(ModUtils.IO.SanitizePath(assembliesFilePath))}"); + } + + successState = _assemblyManager.LoadAssembliesFromLocations(pair.Value.AssembliesFilePaths, ref id); + + // error handling + if (successState is not AssemblyLoadingSuccessState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the binary assemblies for package {pair.Key.Name}. Error: {successState.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + continue; + } + } + + // try compile scripts to assemblies + if (pair.Value.ScriptsFilePaths is not null && pair.Value.ScriptsFilePaths.Any()) + { + ModUtils.Logging.PrintMessage($"Loading scripts for CPackage {pair.Key.Name}"); + List syntaxTrees = new(); + + syntaxTrees.Add(GetPackageScriptImports()); + bool abortPackage = false; + // load scripts data from files + foreach (string scriptPath in pair.Value.ScriptsFilePaths) + { + var state = ModUtils.IO.GetOrCreateFileText(scriptPath, out string fileText, null, false); + // could not load file data + if (state is not ModUtils.IO.IOActionResultState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {state.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + + try + { + CancellationToken token = new(); + syntaxTrees.Add(SyntaxFactory.ParseSyntaxTree(fileText, ScriptParseOptions, scriptPath, Encoding.Default, token)); + // cancel if parsing failed + if (token.IsCancellationRequested) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: Syntax Parse Error."); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + } + catch (Exception e) + { + // unknown error + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {e.Message}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + abortPackage = true; + break; + } + + } + + if (abortPackage) + continue; + + // try compile + successState = _assemblyManager.LoadAssemblyFromMemory( + pair.Key.Name.Replace(" ",""), + syntaxTrees, + null, + CompilationOptions, + ref id, publicizedAssemblies); + + if (successState is not AssemblyLoadingSuccessState.Success) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to compile script assembly for package {pair.Key.Name}. Error: {successState.ToString()}"); + UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies); + continue; + } + } + + // something was loaded, add to index + if (id != Guid.Empty) + { + ModUtils.Logging.PrintMessage($"Assemblies from CPackage {pair.Key.Name} loaded with Guid {id}."); + _loadedCompiledPackageAssemblies.Add(pair.Key, id); + _reverseLookupGuidList.Add(id, pair.Key); + } + } + + // update loaded packages to exclude bad packages + _currentPackagesByLoadOrder.AddRange(toLoad + .Where(p => !badPackages.Contains(p.Key)) + .Select(p => p.Key)); + + // build list of plugins + foreach (var pair in _loadedCompiledPackageAssemblies) + { + if (_assemblyManager.TryGetSubTypesFromACL(pair.Value, out var types)) + { + _pluginTypes[pair.Value] = types.ToImmutableHashSet(); + foreach (var type in _pluginTypes[pair.Value]) + { + ModUtils.Logging.PrintMessage($"Loading type: {type.Name}"); + } + } + } + + this.AssembliesLoaded = true; + return AssemblyLoadingSuccessState.Success; + + + bool ShouldRunPackage(ContentPackage package, RunConfig config) + { + if (config.AutoGenerated) + return false; + return (!_luaCsSetup.Config.TreatForcedModsAsNormal && config.IsForced()) + || (ContentPackageManager.EnabledPackages.All.Contains(package) && config.IsForcedOrStandard()); + } + + void UpdatePackagesToDisable(ref HashSet list, + ContentPackage newDisabledPackage, + IEnumerable>> dependenciesMap) + { + list.Add(newDisabledPackage); + foreach (var package in dependenciesMap) + { + if (package.Value.Contains(newDisabledPackage)) + list.Add(newDisabledPackage); + } + } + } + + /// + /// Executes instantiated plugins' Initialize() and OnLoadCompleted() methods. + /// + public void RunPluginsInit() + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without any loaded assemblies!"); + return; + } + + if (!PluginsInitialized) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without type instantiation!"); + return; + } + + if (PluginsLoaded) + return; + + foreach (var contentPlugins in _loadedPlugins) + { + // init + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.Initialize(), $"{nameof(IAssemblyPlugin.Initialize)}", plugin.GetType().Name); + } + } + + foreach (var contentPlugins in _loadedPlugins) + { + // load complete + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.OnLoadCompleted(), $"{nameof(IAssemblyPlugin.OnLoadCompleted)}", plugin.GetType().Name); + } + } + + PluginsLoaded = true; + } + + /// + /// Executes instantiated plugins' PreInitPatching() method. + /// + public void RunPluginsPreInit() + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without any loaded assemblies!"); + return; + } + + if (!PluginsInitialized) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without type initialization!"); + return; + } + + if (PluginsPreInit) + { + return; + } + + foreach (var contentPlugins in _loadedPlugins) + { + // init + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.PreInitPatching(), $"{nameof(IAssemblyPlugin.PreInitPatching)}", plugin.GetType().Name); + } + } + + PluginsPreInit = true; + } + + /// + /// Initializes plugin types that are registered. + /// + /// + public void InstantiatePlugins(bool force = false) + { + if (!AssembliesLoaded) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to instantiate plugins without any loaded assemblies!"); + return; + } + + if (PluginsInitialized) + { + if (force) + UnloadPlugins(); + else + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to load plugins when they were already loaded!"); + return; + } + } + + foreach (var pair in _pluginTypes) + { + // instantiate + foreach (Type type in pair.Value) + { + if (!_loadedPlugins.ContainsKey(pair.Key)) + _loadedPlugins.Add(pair.Key, new()); + else if (_loadedPlugins[pair.Key] is null) + _loadedPlugins[pair.Key] = new(); + IAssemblyPlugin plugin = null; + try + { + plugin = (IAssemblyPlugin)Activator.CreateInstance(type); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}. Now disposing..."); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}"); +#endif + TryRun(() => plugin?.Dispose(), "Dispose", type.FullName ?? type.Name); + + plugin = null; + } + if (plugin is not null) + _loadedPlugins[pair.Key].Add(plugin); + else + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}"); + } + } + + PluginsInitialized = true; + } + + /// + /// Unloads all plugins by calling Dispose() on them. Note: This does not remove their external references nor + /// unregister their types. + /// + public void UnloadPlugins() + { + foreach (var contentPlugins in _loadedPlugins) + { + foreach (var plugin in contentPlugins.Value) + { + TryRun(() => plugin.Dispose(), $"{nameof(IAssemblyPlugin.Dispose)}", plugin.GetType().Name); + } + contentPlugins.Value.Clear(); + } + + _loadedPlugins.Clear(); + + PluginsInitialized = false; + PluginsPreInit = false; + PluginsLoaded = false; + } + + + /// + /// Gets the RunConfig.xml for the given package located at [cp_root]/CSharp/RunConfig.xml. + /// Generates a default config if one is not found. + /// + /// The package to search for. + /// RunConfig data. + /// True if a config is loaded, false if one was created. + public static bool GetOrCreateRunConfig(ContentPackage package, out RunConfig config) + { + var path = System.IO.Path.Combine(Path.GetFullPath(package.Dir), "CSharp", "RunConfig.xml"); + if (!File.Exists(path)) + { + config = new RunConfig(true).Sanitize(); + return false; + } + return ModUtils.IO.LoadOrCreateTypeXml(out config, path, () => new RunConfig(true).Sanitize(), false); + } + + #endregion + + #region INTERNALS + + private void TryRun(Action action, string messageMethodName, string messageTypeName) + { + try + { + action?.Invoke(); + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while running {messageMethodName}() on plugin of type {messageTypeName}"); +#if DEBUG + ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}"); +#endif + } + } + + private void AssemblyManagerOnAssemblyUnloading(Assembly assembly) + { + ReflectionUtils.RemoveAssemblyFromCache(assembly); + } + + private void AssemblyManagerOnAssemblyLoaded(Assembly assembly) + { + //ReflectionUtils.AddNonAbstractAssemblyTypes(assembly); + // As ReflectionUtils.GetDerivedNonAbstract is only used for Prefabs & Barotrauma-specific implementing types, + // we can safely not register System/Core assemblies. + if (assembly.FullName is not null && assembly.FullName.StartsWith("System.")) + return; + ReflectionUtils.AddNonAbstractAssemblyTypes(assembly, true); + } + + internal CsPackageManager([NotNull] AssemblyManager assemblyManager, [NotNull] LuaCsSetup luaCsSetup) + { + this._assemblyManager = assemblyManager; + this._luaCsSetup = luaCsSetup; + } + + ~CsPackageManager() + { + this.Dispose(); + } + + private static bool TryScanPackageForScripts(ContentPackage package, out ImmutableList scriptFilePaths) + { + string pathShared = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", "Shared"); + string pathArch = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", ARCHITECTURE_TARGET); + + List files = new(); + + if (Directory.Exists(pathShared)) + files.AddRange(Directory.GetFiles(pathShared, SCRIPT_FILE_REGEX, SearchOption.AllDirectories)); + if (Directory.Exists(pathArch)) + files.AddRange(Directory.GetFiles(pathArch, SCRIPT_FILE_REGEX, SearchOption.AllDirectories)); + + if (files.Count > 0) + { + scriptFilePaths = files.ToImmutableList(); + return true; + } + scriptFilePaths = ImmutableList.Empty; + return false; + } + + private static bool TryScanPackagesForAssemblies(ContentPackage package, out ImmutableList assemblyFilePaths) + { + string path = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "bin", ARCHITECTURE_TARGET, PLATFORM_TARGET); + + if (!Directory.Exists(path)) + { + assemblyFilePaths = ImmutableList.Empty; + return false; + } + + assemblyFilePaths = System.IO.Directory.GetFiles(path, ASSEMBLY_FILE_REGEX, SearchOption.AllDirectories) + .ToImmutableList(); + return assemblyFilePaths.Count > 0; + } + + private static RunConfig GetRunConfigForPackage(ContentPackage package) + { + if (!GetOrCreateRunConfig(package, out var config)) + config.AutoGenerated = true; + return config; + } + + private IEnumerable BuildPackagesList() + { + // get unique list of content packages. + // Note: there is an old issue where the AllPackages group + // would sometimes not contain packages downloaded from the host, so we union enabled. + return ContentPackageManager.AllPackages.Union(ContentPackageManager.EnabledPackages.All).Where(pack => !pack.Name.ToLowerInvariant().Equals("vanilla")); + } + + + private static SyntaxTree GetPackageScriptImports() => BaseAssemblyImports; + + + /// + /// Builds a list of ContentPackage dependencies for each of the packages in the list. Note: All dependencies must be included in the provided list of packages. + /// + /// List of packages to check + /// Dependencies by package + /// True if all dependencies were found. + private static bool TryBuildDependenciesMap(ImmutableList packages, out Dictionary> dependenciesMap) + { + bool reliableMap = true; // remains true if all deps were found. + dependenciesMap = new(); + foreach (var package in packages) + { + dependenciesMap.Add(package, new()); + if (GetOrCreateRunConfig(package, out var config)) + { + if (config.Dependencies is null || !config.Dependencies.Any()) + continue; + + foreach (RunConfig.Dependency dependency in config.Dependencies) + { + ContentPackage dep = packages.FirstOrDefault(p => + (dependency.SteamWorkshopId != 0 && p.TryExtractSteamWorkshopId(out var steamWorkshopId) + && steamWorkshopId.Value == dependency.SteamWorkshopId) + || (!dependency.PackageName.IsNullOrWhiteSpace() && p.Name.ToLowerInvariant().Contains(dependency.PackageName.ToLowerInvariant())), null); + + if (dep is not null) + { + dependenciesMap[package].Add(dep); + } + else + { + ModUtils.Logging.PrintError($"Warning! The ContentPackage {package.Name} lists a dependency of (STEAMID: {dependency.SteamWorkshopId}, PackageName: {dependency.PackageName}) but it could not be found in the to-be-loaded CSharp packages list!"); + reliableMap = false; + } + } + } + else + { + ModUtils.Logging.PrintMessage($"Warning! Could not retrieve RunConfig for ContentPackage {package.Name}!"); + } + } + + return reliableMap; + } + + /// + /// Given a table of packages and dependent packages, will sort them by dependency loading order along with packages + /// that cannot be loaded due to errors or failing the predicate checks. + /// + /// A dictionary/map with key as the package and the elements as it's dependencies. + /// List of packages that are ready to load and in the correct order. + /// Packages with errors or cyclic dependencies. Element is error message. Null if empty. + /// Optional: Allows for a custom checks to be performed on each package. + /// Returns a bool indicating if the package is ready to load. + /// Whether or not the process produces a usable list. + private static bool OrderAndFilterPackagesByDependencies( + Dictionary> packages, + out IEnumerable readyToLoad, + out IEnumerable> cannotLoadPackages, + Func packageChecksPredicate = null) + { + HashSet completedPackages = new(); + List readyPackages = new(); + Dictionary unableToLoad = new(); + HashSet currentNodeChain = new(); + + readyToLoad = readyPackages; + + try + { + foreach (var toProcessPack in packages) + { + ProcessPackage(toProcessPack.Key, toProcessPack.Value); + } + + PackageProcRet ProcessPackage(ContentPackage packageToProcess, IEnumerable dependencies) + { + //cyclic handling + if (unableToLoad.ContainsKey(packageToProcess)) + { + return PackageProcRet.BadPackage; + } + + // already processed + if (completedPackages.Contains(packageToProcess)) + { + return PackageProcRet.AlreadyCompleted; + } + + // cyclic check + if (currentNodeChain.Contains(packageToProcess)) + { + StringBuilder sb = new(); + sb.AppendLine("Error: Cyclic Dependency. ") + .Append( + "The following ContentPackages rely on eachother in a way that makes it impossible to know which to load first! ") + .Append( + "Note: the package listed twice shows where the cycle starts/ends and is not necessarily the problematic package."); + int i = 0; + foreach (var package in currentNodeChain) + { + i++; + sb.AppendLine($"{i}. {package.Name}"); + } + + sb.AppendLine($"{i}. {packageToProcess.Name}"); + unableToLoad.Add(packageToProcess, sb.ToString()); + completedPackages.Add(packageToProcess); + return PackageProcRet.BadPackage; + } + + if (packageChecksPredicate is not null && !packageChecksPredicate.Invoke(packageToProcess)) + { + unableToLoad.Add(packageToProcess, $"Unable to load package {packageToProcess.Name} due to failing checks."); + completedPackages.Add(packageToProcess); + return PackageProcRet.BadPackage; + } + + currentNodeChain.Add(packageToProcess); + + foreach (ContentPackage dependency in dependencies) + { + // The mod lists a dependent that was not found during the discovery phase. + if (!packages.ContainsKey(dependency)) + { + // search to see if it's enabled + if (!ContentPackageManager.EnabledPackages.All.Contains(dependency)) + { + // present warning but allow loading anyways, better to let the user just disable the package if it's really an issue. + ModUtils.Logging.PrintError( + $"Warning: the ContentPackage of {packageToProcess.Name} requires the Dependency {dependency.Name} but this package wasn't found in the enabled mods list!"); + } + + continue; + } + + var ret = ProcessPackage(dependency, packages[dependency]); + + if (ret is PackageProcRet.BadPackage) + { + if (!unableToLoad.ContainsKey(packageToProcess)) + { + unableToLoad.Add(packageToProcess, $"Error: Dependency failure. Failed to load {dependency.Name}"); + } + currentNodeChain.Remove(packageToProcess); + if (!completedPackages.Contains(packageToProcess)) + { + completedPackages.Add(packageToProcess); + } + return PackageProcRet.BadPackage; + } + } + + currentNodeChain.Remove(packageToProcess); + completedPackages.Add(packageToProcess); + readyPackages.Add(packageToProcess); + return PackageProcRet.Completed; + } + } + catch (Exception e) + { + ModUtils.Logging.PrintError($"Error while generating dependency loading order! Exception: {e.Message}"); +#if DEBUG + ModUtils.Logging.PrintError($"Stack Trace: {e.StackTrace}"); +#endif + cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null; + return false; + } + cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null; + return true; + } + + private enum PackageProcRet : byte + { + AlreadyCompleted, + Completed, + BadPackage + } + + private record LoadableData(ImmutableList AssembliesFilePaths, ImmutableList ScriptsFilePaths); + + #endregion +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs new file mode 100644 index 0000000000..5a450ba744 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/IAssemblyPlugin.cs @@ -0,0 +1,22 @@ +using System; + +namespace Barotrauma; + +public interface IAssemblyPlugin : IDisposable +{ + /// + /// Called on plugin normal, use this for basic/core loading that does not rely on any other modded content. + /// + void Initialize(); + + /// + /// Called once all plugins have been loaded. if you have integrations with any other mod, put that code here. + /// + void OnLoadCompleted(); + + + /// + /// Called before Barotrauma initializes vanilla content. WARNING: This method may be called before Initialize()! + /// + void PreInitPatching(); +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs new file mode 100644 index 0000000000..8c29f07e44 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/MemoryFileAssemblyContextLoader.cs @@ -0,0 +1,289 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.Loader; +using System.Threading; +using Barotrauma; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Emit; + +namespace Barotrauma; + +/// +/// AssemblyLoadContext to compile from syntax trees in memory and to load from disk/file. Provides dependency resolution. +/// [IMPORTANT] Only supports 1 in-memory compiled assembly at a time. Use more instances if you need more. +/// [IMPORTANT] All file assemblies required for the compilation of syntax trees should be loaded first. +/// +public class MemoryFileAssemblyContextLoader : AssemblyLoadContext +{ + // public + // ReSharper disable MemberCanBePrivate.Global + public Assembly CompiledAssembly { get; private set; } = null; + public byte[] CompiledAssemblyImage { get; private set; } = null; + // ReSharper restore MemberCanBePrivate.Global + // internal + private readonly Dictionary _dependencyResolvers = new(); // path-folder, resolver + protected bool IsResolving; //this is to avoid circular dependency lookup. + private AssemblyManager _assemblyManager; + public bool IsTemplateMode { get; set; } = false; + + public MemoryFileAssemblyContextLoader(AssemblyManager assemblyManager) : base(isCollectible: true) + { + this._assemblyManager = assemblyManager; + } + + + /// + /// Try to load the list of disk-file assemblies. + /// + /// Operation success or failure reason. + public AssemblyLoadingSuccessState LoadFromFiles([NotNull] IEnumerable assemblyFilePaths) + { + if (assemblyFilePaths is null) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(LoadFromFiles)}() | The supplied filepath list is null."); + + foreach (string filepath in assemblyFilePaths) + { + // path verification + if (filepath.IsNullOrWhiteSpace()) + continue; + string sanitizedFilePath = System.IO.Path.GetFullPath(filepath.CleanUpPath()); + string directoryKey = System.IO.Path.GetDirectoryName(sanitizedFilePath); + + if (directoryKey is null) + return AssemblyLoadingSuccessState.BadFilePath; + + // setup dep resolver if not available + if (!_dependencyResolvers.ContainsKey(directoryKey) || _dependencyResolvers[directoryKey] is null) + { + _dependencyResolvers[directoryKey] = new AssemblyDependencyResolver(sanitizedFilePath); // supply the first assembly to be loaded + } + + // try loading the assemblies + try + { + LoadFromAssemblyPath(sanitizedFilePath); + } + // on fail of any we're done because we assume that loaded files are related. This ACL needs to be unloaded and collected. + catch (ArgumentNullException ane) + { + return AssemblyLoadingSuccessState.BadFilePath; + } + catch (ArgumentException ae) + { + return AssemblyLoadingSuccessState.BadFilePath; + } + catch (FileLoadException fle) + { + return AssemblyLoadingSuccessState.CannotLoadFile; + } + catch (FileNotFoundException fne) + { + return AssemblyLoadingSuccessState.NoAssemblyFound; + } + catch (BadImageFormatException bfe) + { + return AssemblyLoadingSuccessState.InvalidAssembly; + } + catch (Exception e) + { +#if SERVER + LuaCsLogger.LogError($"Unable to load dependency assembly file at {filepath.CleanUpPath()} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}"); +#elif CLIENT + LuaCsLogger.ShowErrorOverlay($"Unable to load dependency assembly file at {filepath} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}"); +#endif + return AssemblyLoadingSuccessState.ACLLoadFailure; + } + } + + return AssemblyLoadingSuccessState.Success; + } + + + /// + /// Compiles the supplied syntaxtrees and options into an in-memory assembly image. + /// Builds metadata from loaded assemblies, only supply your own if you have in-memory images not managed by the + /// AssemblyManager class. + /// + /// Name of the assembly. Must be supplied for in-memory assemblies. + /// Syntax trees to compile into the assembly. + /// Metadata to be used for compilation. + /// [IMPORTANT] This method builds metadata from loaded assemblies, only supply your own if you have in-memory + /// images not managed by the AssemblyManager class. + /// CSharp compilation options. This method automatically adds the 'IgnoreAccessChecks' property for compilation. + /// Will contain any diagnostic messages for compilation failure. + /// Additional assemblies located in the FileSystem to build metadata references from. + /// Assemblies here will have duplicates by the same name that are currently loaded filtered out. + /// Success state of the operation. + /// Throws exception if any of the required arguments are null. + public AssemblyLoadingSuccessState CompileAndLoadScriptAssembly( + [NotNull] string assemblyName, + [NotNull] IEnumerable syntaxTrees, + IEnumerable externMetadataReferences, + [NotNull] CSharpCompilationOptions compilationOptions, + out string compilationMessages, + IEnumerable externFileAssemblyReferences = null) + { + compilationMessages = ""; + + if (this.CompiledAssembly is not null) + { + return AssemblyLoadingSuccessState.AlreadyLoaded; + } + + var externAssemblyRefs = externFileAssemblyReferences is not null ? externFileAssemblyReferences.ToImmutableList() : ImmutableList.Empty; + var externAssemblyNames = externAssemblyRefs.Any() ? externAssemblyRefs + .Where(a => a.FullName is not null) + .Select(a => a.FullName).ToImmutableHashSet() + : ImmutableHashSet.Empty; + + // verifications + if (assemblyName.IsNullOrWhiteSpace()) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied assembly name is null!"); + + if (syntaxTrees is null) + throw new ArgumentNullException( + $"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied syntax tree is null!"); + + // add external references + List metadataReferences = new(); + if (externMetadataReferences is not null) + metadataReferences.AddRange(externMetadataReferences); + + // build metadata refs from global where not an in-memory compiled assembly and not the same assembly as supplied. + metadataReferences.AddRange(AppDomain.CurrentDomain.GetAssemblies() + .Where(a => + { + if (a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")) + return false; + if (a.FullName is null) + return true; + return !externAssemblyNames.Contains(a.FullName); // exclude duplicates + }) + .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) + .Union(externAssemblyRefs // add custom supplied assemblies + .Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))) + .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) + ).ToList()); + + // build metadata refs from in-memory images + foreach (var loadedAcl in _assemblyManager.GetAllLoadedACLs()) + { + if (loadedAcl.Acl.CompiledAssemblyImage is null || loadedAcl.Acl.CompiledAssemblyImage.Length == 0) + continue; + metadataReferences.Add(MetadataReference.CreateFromImage(loadedAcl.Acl.CompiledAssemblyImage)); + } + + // Change inaccessible options to allow public access to restricted members + var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic); + topLevelBinderFlagsProperty?.SetValue(compilationOptions, (uint)1 << 22); + + // begin compilation + using var memoryCompilation = new MemoryStream(); + // compile, emit + var result = CSharpCompilation.Create(assemblyName, syntaxTrees, metadataReferences, compilationOptions).Emit(memoryCompilation); + // check for errors + if (!result.Success) + { + IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); + foreach (Diagnostic diagnostic in failures) + { + compilationMessages += $"\n{diagnostic}"; + } + + return AssemblyLoadingSuccessState.InvalidAssembly; + } + + // read compiled assembly from memory stream into an in-memory assembly & image + memoryCompilation.Seek(0, SeekOrigin.Begin); // reset + try + { + CompiledAssembly = LoadFromStream(memoryCompilation); + CompiledAssemblyImage = memoryCompilation.ToArray(); + } + catch (Exception e) + { +#if SERVER + LuaCsLogger.LogError($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}"); +#elif CLIENT + LuaCsLogger.ShowErrorOverlay($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}"); +#endif + return AssemblyLoadingSuccessState.CannotLoadFromStream; + } + + return AssemblyLoadingSuccessState.Success; + } + + [SuppressMessage("ReSharper", "ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract")] + protected override Assembly Load(AssemblyName assemblyName) + { + if (IsResolving) + return null; //circular resolution fast exit. + + try + { + IsResolving = true; + + // resolve self collection + Assembly ass = this.Assemblies.FirstOrDefault(a => + a.FullName is not null && a.FullName.Equals(assemblyName.FullName), null); + if (ass is not null) + return ass; + + // resolve to local folders + foreach (KeyValuePair pair in _dependencyResolvers) + { + var asspath = pair.Value.ResolveAssemblyToPath(assemblyName); + if (asspath is null) + continue; + ass = LoadFromAssemblyPath(asspath); + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + if (ass is not null) + return ass; + } + + //try resolve against other loaded alcs + foreach (var loadedAcL in _assemblyManager.GetAllLoadedACLs()) + { + if (loadedAcL.Acl is null || loadedAcL.Acl.IsTemplateMode) continue; + + try + { + ass = loadedAcL.Acl.LoadFromAssemblyName(assemblyName); + if (ass is not null) + return ass; + } + catch + { + // LoadFromAssemblyName throws, no need to propagate + } + } + + ass = AssemblyLoadContext.Default.LoadFromAssemblyName(assemblyName); + if (ass is not null) + return ass; + } + finally + { + IsResolving = false; + } + + return null; + } + + + private new void Unload() + { + CompiledAssembly = null; + CompiledAssemblyImage = null; + base.Unload(); + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs new file mode 100644 index 0000000000..93cbb75f78 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Plugins/RunConfig.cs @@ -0,0 +1,111 @@ +using System; +using System.Xml.Serialization; + +namespace Barotrauma; + +[Serializable] +public sealed class RunConfig +{ + /// + /// How should scripts be run on the server. + /// + [XmlElement(ElementName = "Server")] public string Server; + + /// + /// How should scripts be run on the client. + /// + [XmlElement(ElementName = "Client")] public string Client; + + /// + /// List of dependencies by either Steam Workshop ID or by Partial Inclusive Name (ie. "ModDep" will match a mod named "A ModDependency"). + /// PIN Dependency checks if ContentPackage names contains the dependency string. + /// + [XmlArrayItem(ElementName = "Dependency", IsNullable = true, Type = typeof(Dependency))] + [XmlArray] + public Dependency[] Dependencies { get; set; } + + [XmlElement(ElementName = "AutoGenerated")] + public bool AutoGenerated { get; set; } + + public RunConfig(bool autoGenerated) + { + this.AutoGenerated = autoGenerated; + if (autoGenerated) + { + (Client, Server) = ("None", "None"); + } + } + + public RunConfig() { } // For serialization use + + [Serializable] + public sealed class Dependency + { + /// + /// Steam Workshop ID of the dependency. + /// + [XmlElement(ElementName = "SteamWorkshopId")] + public ulong SteamWorkshopId; + + /// + /// Package Name of the dependency. Not needed if SteamWorkshopId is set. + /// + [XmlElement(ElementName = "PackageName")] + public string PackageName; + } + + public RunConfig Sanitize() + { + try + { + Client = SanitizeRunSetting(Client); + } + catch (Exception e) + { + Client = "None"; + } + + try + { + Server = SanitizeRunSetting(Server); + } + catch (Exception e) + { + Server = "None"; + } + + Dependencies ??= new RunConfig.Dependency[] { }; + + static string SanitizeRunSetting(string str) => + str switch + { + null => "None", + "" => "None", + " " => "None", + _ => str[0].ToString().ToUpper() + str.Substring(1).ToLower() + }; + + return this; + } + + public bool IsForced() + { +#if CLIENT + return this.Client.Equals("Forced"); +#elif SERVER + return this.Server.Equals("Forced"); +#endif + } + + public bool IsStandard() + { +#if CLIENT + return this.Client.Equals("Standard"); +#elif SERVER + return this.Server.Equals("Standard"); +#endif + } + + public bool IsForcedOrStandard() => this.IsForced() || this.IsStandard(); + +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs b/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs index f3e6bcdf7a..ac8c601ac7 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Utils/ReflectionUtils.cs @@ -1,5 +1,6 @@ #nullable enable using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -9,26 +10,43 @@ namespace Barotrauma { public static class ReflectionUtils { - private static readonly Dictionary> cachedNonAbstractTypes - = new Dictionary>(); - - private static readonly Dictionary>> cachedDerivedNonAbstract - = new Dictionary>>(); + private static readonly ConcurrentDictionary> CachedNonAbstractTypes + = new ConcurrentDictionary>(); + private static readonly ConcurrentDictionary> TypeSearchCache = new(); public static IEnumerable GetDerivedNonAbstract() { Type t = typeof(T); - Assembly assembly = typeof(T).Assembly; - lock (cachedNonAbstractTypes) + string typeName = t.FullName ?? t.Name; + + // search quick lookup cache + if (TypeSearchCache.TryGetValue(typeName, out var value)) { - if (!cachedNonAbstractTypes.ContainsKey(assembly)) - { - AddNonAbstractAssemblyTypes(assembly); - } + return value; } + + // doesn't exist so let's add it. + Assembly assembly = typeof(T).Assembly; + if (!CachedNonAbstractTypes.ContainsKey(assembly)) + { + AddNonAbstractAssemblyTypes(assembly); + } + + // build cache from registered assemblies' types. + var list = CachedNonAbstractTypes.Values + .SelectMany(arr => arr.Where(type => type.IsSubclassOf(t))) + .ToImmutableArray(); - #warning TODO: Add safety checks in case an assembly is unloaded without being removed from the cache. - return cachedNonAbstractTypes.Values.SelectMany(s => s.Where(t => t.IsSubclassOf(typeof(T)))); + if (list.Length == 0) + { + return ImmutableArray.Empty; // No types, don't add to cache + } + + if (!TypeSearchCache.TryAdd(typeName, list)) + { + DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Error while adding to quick lookup cache."); + } + return list; } /// @@ -38,7 +56,7 @@ public static IEnumerable GetDerivedNonAbstract() /// Whether or not to overwrite an entry if the assembly already exists within it. public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite = false) { - if (cachedNonAbstractTypes.ContainsKey(assembly)) + if (CachedNonAbstractTypes.ContainsKey(assembly)) { if (!overwrite) { @@ -46,15 +64,20 @@ public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite $"ReflectionUtils::AddNonAbstractAssemblyTypes() | The assembly [{assembly.GetName()}] already exists in the cache."); return; } - cachedNonAbstractTypes.Remove(assembly); + + CachedNonAbstractTypes.Remove(assembly, out _); } try { - if (!cachedNonAbstractTypes.TryAdd(assembly, assembly.GetTypes().Where(t => !t.IsAbstract).ToImmutableArray())) + if (!CachedNonAbstractTypes.TryAdd(assembly, assembly.GetSafeTypes().Where(t => !t.IsAbstract).ToImmutableArray())) { DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Unable to add types from Assembly to cache."); } + else + { + TypeSearchCache.Clear(); // Needs to be rebuilt to include potential new types + } } catch (ReflectionTypeLoadException e) { @@ -66,8 +89,22 @@ public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite /// Removes an assembly from the cache for Barotrauma's Type lookup. /// /// Assembly to remove. - public static void RemoveAssemblyFromCache(Assembly assembly) => cachedNonAbstractTypes.Remove(assembly); + public static void RemoveAssemblyFromCache(Assembly assembly) + { + CachedNonAbstractTypes.Remove(assembly, out _); + TypeSearchCache.Clear(); + } + /// + /// Clears all cached assembly data and rebuilds types list only to include base Barotrauma types. + /// + internal static void ResetCache() + { + CachedNonAbstractTypes.Clear(); + CachedNonAbstractTypes.TryAdd(typeof(ReflectionUtils).Assembly, typeof(ReflectionUtils).Assembly.GetSafeTypes().ToImmutableArray()); + TypeSearchCache.Clear(); + } + public static Option ParseDerived(TInput input) where TInput : notnull where TBase : notnull { static Option none() => Option.None(); diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs index 652ee758da..813d6fee8a 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs @@ -259,10 +259,10 @@ private class PatchTargetModifyParams { public bool ran; - public void Run(int a, out string outString, ref byte refByte, string b) + public void Run(out string result, int a, string b, ref byte c) { ran = true; - outString = a + b + refByte; + result = a + b + c; } } @@ -274,12 +274,13 @@ public void TestModifyParameters() using var patchHandle = luaCs.AddPrefix(@" ptable['a'] = Int32(100) ptable['b'] = 'abc' - ptable['refByte'] = Byte(4) + ptable['c'] = Byte(4) ", nameof(PatchTargetModifyParams.Run)); - byte refByte = 123; - target.Run(5, out var outString, ref refByte, "foo"); + byte c = 123; + target.Run(out var result, 5, "foo", ref c); Assert.True(target.ran); - Assert.Equal("100abc4", outString); + Assert.Equal(4, c); + Assert.Equal("100abc4", result); } private class PatchTargetVector2 @@ -309,13 +310,63 @@ public void TestParameterValueType() private class PatchTargetAmbiguous { + public bool ran; + public PatchTargetAmbiguous() { } - public PatchTargetAmbiguous(int a) { } + public PatchTargetAmbiguous(int a) + { + throw new NotImplementedException(); + } - public void Blah() { } + public void Run(out string result, int a, string b, ref byte c) + { + ran = true; + result = a + b + c; + } - public void Blah(int a) { } + public void Run(string result, int a, string b, byte c) + { + throw new NotImplementedException(); + } + + public void Run(out string result, int a, string b, byte c) + { + throw new NotImplementedException(); + } + + public void Run(string result, int a, string b, ref byte c) + { + throw new NotImplementedException(); + } + + public void Run() + { + throw new NotImplementedException(); + } + } + + [Fact] + public void TestPatchDisambiguation() + { + using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); + var target = new PatchTargetAmbiguous(); + using var patchHandle = luaCs.AddPrefix(@" + ptable['a'] = Int32(100) + ptable['b'] = 'abc' + ptable['c'] = Byte(4) + ", nameof(PatchTargetAmbiguous.Run), new[] + { + $"out {typeof(string).FullName!}", + typeof(int).FullName!, + typeof(string).FullName!, + $"ref {typeof(byte).FullName!}", + }); + byte c = 123; + target.Run(out var result, 5, "foo", ref c); + Assert.True(target.ran); + Assert.Equal(4, c); + Assert.Equal("100abc4", result); } [Fact] @@ -334,11 +385,11 @@ public void TestPatchAmbiguous() Assert.Throws(() => { - using var postfixHandle = luaCs.AddPostfix("", nameof(PatchTargetAmbiguous.Blah)); + using var postfixHandle = luaCs.AddPostfix("", nameof(PatchTargetAmbiguous.Run)); }); Assert.Throws(() => { - using var prefixHandle = luaCs.AddPrefix("", nameof(PatchTargetAmbiguous.Blah)); + using var prefixHandle = luaCs.AddPrefix("", nameof(PatchTargetAmbiguous.Run)); }); } diff --git a/luacs-docs/lua/lua/Networking.lua b/luacs-docs/lua/lua/Networking.lua index a87cd1fa6d..016df5ff86 100644 --- a/luacs-docs/lua/lua/Networking.lua +++ b/luacs-docs/lua/lua/Networking.lua @@ -16,13 +16,17 @@ Networking.FileSenderMaxPacketsPerUpdate = 4 -- @realm server Networking.LastClientListUpdateID = 0 ---- Send a post HTTP Request, callback is called with an argument result string. --- @realm server +--- Send a GET HTTP Request, callback is called with the result string message, status code and headers. only url anda callback are optional. +-- @realm shared function Networking.HttpGet(url, callback, textData, contentType) end ---- Send a get HTTP Request, callback is called with an argument result string. --- @realm server -function Networking.HttpPost(url, callback) end +--- Send a POST HTTP Request, callback is called with the result string message, status code and headers. +-- @realm shared +function Networking.HttpPost(url, callback, textData, contentType) end + +--- Sends a HTTP Request, callback is called with the result string message, status code and headers. If savePath is specified, the result will be saved as binary format in the specified path. only url and callback are optional. +-- @realm shared +function Networking.HttpRequest(url, callback, data, method, contentType, headers, savePath) end --- Creates a new net message, returns an IWriteMessage -- @treturn IWriteMessage netMessage