diff --git a/Create.cs b/Create.cs index 5899198..ea07568 100644 --- a/Create.cs +++ b/Create.cs @@ -116,17 +116,23 @@ private async Task InstallPackageDependencies(string language) switch (language) { case "Python (Tensorflow)": - packageRequirements.Add(CheckForDiscreteGPU() ? "tensorflow" : "tensorflow"); + packageRequirements.Add(CheckForDiscreteGPU() ? "tensorflow-gpu" : "tensorflow"); if (comboBox1.Text == "Text generation (RNN)") { packageRequirements.Add("numpy"); - await DownloadFilesForRNN(textBox1.Text); + await DownloadFilesForTGRNN(textBox1.Text); } break; case "Python (PyTorch)": packageRequirements.Add("torch"); + + if (textBox1.Text == "Text generation (LSTM)") + { + packageRequirements.Add("numpy"); + await DownloadFilesForTGLSTM(textBox1.Text); + } break; } await UpdateProgressBarAsync(35); @@ -143,7 +149,7 @@ private async Task InstallPackageDependencies(string language) await UpdateProgressBarAsync(50); } - private async Task DownloadFilesForRNN(string projectPath) + private async Task DownloadFilesForTGRNN(string projectPath) { using (var client = new WebClient()) { @@ -162,6 +168,25 @@ private async Task DownloadFilesForRNN(string projectPath) } } + private async Task DownloadFilesForTGLSTM(string projectPath) + { + using (var client = new WebClient()) + { + try + { + var baseUri = "https://raw.githubusercontent.com/Lithicsoft/Lithicsoft-Trainer-Studio/main/lstm_text_generation/"; + await client.DownloadFileTaskAsync(new Uri(baseUri + "trainer.py"), $"projects\\{projectPath}\\trainer.py"); + await client.DownloadFileTaskAsync(new Uri(baseUri + ".env"), $"projects\\{projectPath}\\.env"); + await UpdateProgressBarAsync(20); + } + catch (WebException ex) + { + MessageBox.Show($"Error downloading file: {ex.Message}", "Exception Error", MessageBoxButtons.OK, MessageBoxIcon.Error); + Environment.Exit(1); + } + } + } + private void textBox1_TextChanged(object sender, EventArgs e) { if (textBox1.Text.Length > 0 && comboBox1.Text.Length > 0 && comboBox2.Text.Length > 0) @@ -196,7 +221,7 @@ private void comboBox2_SelectedIndexChanged(object sender, EventArgs e) } else if (comboBox2.Text == "Python (PyTorch)") { - comboBox1.Items.AddRange(["Text generation"]); + comboBox1.Items.AddRange(["Text generation (LSTM)"]); } else if (comboBox2.Text == "Python (Tensorflow)") { diff --git a/Python.Designer.cs b/Python.Designer.cs index 4a4caed..ab89bcb 100644 --- a/Python.Designer.cs +++ b/Python.Designer.cs @@ -30,9 +30,6 @@ private void InitializeComponent() { label1 = new Label(); tabPage5 = new TabPage(); - richTextBox3 = new RichTextBox(); - textBox4 = new TextBox(); - button6 = new Button(); label7 = new Label(); button5 = new Button(); textBox2 = new TextBox(); @@ -73,9 +70,6 @@ private void InitializeComponent() // // tabPage5 // - tabPage5.Controls.Add(richTextBox3); - tabPage5.Controls.Add(textBox4); - tabPage5.Controls.Add(button6); tabPage5.Controls.Add(label7); tabPage5.Controls.Add(button5); tabPage5.Controls.Add(textBox2); @@ -87,42 +81,14 @@ private void InitializeComponent() tabPage5.Text = "Result"; tabPage5.UseVisualStyleBackColor = true; // - // richTextBox3 - // - richTextBox3.Location = new Point(3, 93); - richTextBox3.Name = "richTextBox3"; - richTextBox3.ReadOnly = true; - richTextBox3.Size = new Size(495, 239); - richTextBox3.TabIndex = 8; - richTextBox3.Text = ""; - // - // textBox4 - // - textBox4.Location = new Point(3, 64); - textBox4.Name = "textBox4"; - textBox4.Size = new Size(414, 23); - textBox4.TabIndex = 7; - textBox4.TextChanged += textBox4_TextChanged; - // - // button6 - // - button6.Enabled = false; - button6.Location = new Point(423, 64); - button6.Name = "button6"; - button6.Size = new Size(75, 23); - button6.TabIndex = 6; - button6.Text = "Predict"; - button6.UseVisualStyleBackColor = true; - button6.Click += button6_Click; - // // label7 // label7.AutoSize = true; - label7.Location = new Point(3, 46); + label7.Location = new Point(3, 44); label7.Name = "label7"; - label7.Size = new Size(64, 15); + label7.Size = new Size(249, 15); label7.TabIndex = 4; - label7.Text = "Test model"; + label7.Text = "Pipeline is not available for this model/project\r\n"; // // button5 // @@ -149,9 +115,9 @@ private void InitializeComponent() label6.AutoSize = true; label6.Location = new Point(3, 0); label6.Name = "label6"; - label6.Size = new Size(60, 15); + label6.Size = new Size(73, 15); label6.TabIndex = 0; - label6.Text = "Model file"; + label6.Text = "Result folder"; // // tabPage3 // @@ -353,16 +319,13 @@ private void InitializeComponent() private TabPage tabPage3; private Label label6; private Button button5; - private Button button6; private Label label7; private ProgressBar progressBar2; private TextBox textBox2; - private TextBox textBox4; private RichTextBox richTextBox2; private TabPage tabPage2; private Button button4; private TextBox textBox3; private ListView listView1; - private RichTextBox richTextBox3; } } diff --git a/Python.cs b/Python.cs index c308dcd..4f78179 100644 --- a/Python.cs +++ b/Python.cs @@ -30,10 +30,7 @@ public Python(string name, string language, string type) projectName = name; label1.Text = $"{type} with {language}"; - if (Directory.Exists($"projects\\{projectName}\\model")) - { - textBox2.Text = $"projects\\{projectName}\\model"; - } + textBox2.Text = $"projects\\{projectName}"; trainParameters = DotEnv.Load($"projects\\{projectName}\\.env"); listView1.View = View.Details; @@ -160,6 +157,8 @@ await Task.Run(() => { MessageBox.Show($"Error training model: {ex.Message}", "Exception Error", MessageBoxButtons.OK, MessageBoxIcon.Error); } + + textBox2.Text = $"projects\\{projectName}"; button3.Enabled = true; } @@ -230,18 +229,6 @@ private void button5_Click(object sender, EventArgs e) } } - private void textBox4_TextChanged(object sender, EventArgs e) - { - if (textBox4.Text.Length > 0) - { - button6.Enabled = true; - } - else - { - button6.Enabled = false; - } - } - private void listView1_SelectedIndexChanged(object sender, EventArgs e) { if (listView1.SelectedItems.Count > 0) diff --git a/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfo.cs b/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfo.cs index d458756..0e62a00 100644 --- a/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfo.cs +++ b/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfo.cs @@ -14,7 +14,7 @@ [assembly: System.Reflection.AssemblyCompanyAttribute("Lithicsoft Trainer Studio")] [assembly: System.Reflection.AssemblyConfigurationAttribute("Debug")] [assembly: System.Reflection.AssemblyFileVersionAttribute("1.0.0.0")] -[assembly: System.Reflection.AssemblyInformationalVersionAttribute("1.0.0+4b53c44cdd402e0cd2955ebed05599f3988fc34e")] +[assembly: System.Reflection.AssemblyInformationalVersionAttribute("1.0.0+8562805c7a99d4a44f355d67b61eabb6a223c099")] [assembly: System.Reflection.AssemblyProductAttribute("Lithicsoft Trainer Studio")] [assembly: System.Reflection.AssemblyTitleAttribute("Lithicsoft Trainer Studio")] [assembly: System.Reflection.AssemblyVersionAttribute("1.0.0.0")] diff --git a/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfoInputs.cache b/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfoInputs.cache index ce05103..92f1fa9 100644 --- a/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfoInputs.cache +++ b/obj/Debug/net8.0-windows/Lithicsoft Trainer Studio.AssemblyInfoInputs.cache @@ -1 +1 @@ -db032ac3aa1a6f6390f5d57811a8d4160a9a54be789895012f815c01332c81c5 +775123c45f918709ae8514842fbdca02b15f0e0657ca360315273dd4fa094d29