Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/rdmp 73 holdouts extractions #1648

Draft
wants to merge 28 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions Rdmp.Core.Tests/DataExport/DataExtraction/ExtractionHoldoutTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) The University of Dundee 2018-2023
// This file is part of the Research Data Management Platform (RDMP).
// RDMP is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
// RDMP is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
// You should have received a copy of the GNU General Public License along with RDMP. If not, see <https://www.gnu.org/licenses/>.

using NUnit.Framework;
using Rdmp.Core.DataExport.DataExtraction.Pipeline;
using Rdmp.Core.DataFlowPipeline;
using System;
using System.IO;
using Rdmp.Core.ReusableLibraryCode.Checks;
using Rdmp.Core.ReusableLibraryCode.Progress;
using System.Data;
using NUnit.Framework.Internal;
using Rdmp.Core.DataExport.DataExtraction.Commands;
using Rdmp.Core.DataExport.DataExtraction.UserPicks;
using Tests.Common.Scenarios;

namespace Rdmp.Core.Tests.DataExport.DataExtraction;

/// <summary>
/// Tests for <see cref="ExtractionHoldout"/>: configuration checks, and the CSV file it writes when
/// holding back a fixed number or a percentage of the rows in a batch.
/// </summary>
internal class ExtractionHoldoutTests : TestsRequiringAnExtractionConfiguration
{
    // The component names its file "holdout_" + the extraction command's ToString() + ".csv".
    //todo use the correct name
    private const string HoldoutFileName = "holdout_TestTable.csv";

    private ExtractionHoldout _holdout;
    private DirectoryInfo _holdoutDir;
    private DataTable _toProcess;

    /// <summary>Full path of the CSV the component is expected to write.</summary>
    private string HoldoutFilePath => Path.Combine(_holdoutDir.FullName, HoldoutFileName);

    /// <summary>
    /// Creates a fresh component, an empty "Holdout" directory under the test work directory and a
    /// six row, single column (FAKE_CHI = 1..6) table to feed through the component.
    /// </summary>
    [SetUp]
    protected override void SetUp()
    {
        base.SetUp();

        _holdout = new ExtractionHoldout();

        _holdoutDir = new DirectoryInfo(Path.Combine(TestContext.CurrentContext.WorkDirectory, "Holdout"));
        if (_holdoutDir.Exists)
        {
            // start every test from an empty directory so files from earlier tests cannot leak in
            _holdoutDir.Delete(true);
        }
        _holdoutDir.Create();

        _toProcess = new DataTable();
        _toProcess.Columns.Add("FAKE_CHI");
        for (var chi = 1; chi <= 6; chi++)
        {
            _toProcess.Rows.Add(chi);
        }
    }

    /// <summary>Builds the dataset extraction command handed to the component under test.</summary>
    private IExtractCommand CreateExtractCommand() =>
        new ExtractDatasetCommand(_configuration, new ExtractableDatasetBundle(_extractableDataSet));

    /// <summary>
    /// Runs one PreInitialize / ProcessPipelineData / Dispose cycle of the component over
    /// <see cref="_toProcess"/> (which the component modifies in place).
    /// </summary>
    private void RunHoldout(IExtractCommand command)
    {
        _holdout.PreInitialize(command, ThrowImmediatelyDataLoadEventListener.Quiet);
        _holdout.ProcessPipelineData(_toProcess, ThrowImmediatelyDataLoadEventListener.Quiet, new GracefulCancellationToken());
        _holdout.Dispose(ThrowImmediatelyDataLoadEventListener.Quiet, null);
    }

    /// <summary>
    /// Asserts the holdout file exists and starts with the FAKE_CHI header followed by
    /// <paramref name="expectedRows"/> single digit rows. Line endings are matched as \r?\n because
    /// the component writes with the platform's newline, which is not "\r\n" everywhere.
    /// </summary>
    private void AssertHoldoutFileStartsWith(int expectedRows)
    {
        Assert.That(HoldoutFilePath, Does.Exist);

        var pattern = @"^FAKE_CHI\r?\n";
        for (var i = 0; i < expectedRows; i++)
        {
            pattern += @"[1-6]\r?\n";
        }

        Assert.That(File.ReadAllText(HoldoutFilePath), Does.Match(pattern));
    }

    [Test]
    public void NoConfiguration()
    {
        var ex = Assert.Throws<Exception>(() => _holdout.Check(ThrowImmediatelyCheckNotifier.Quiet));
        Assert.That(ex.Message, Does.Contain("No holdout file location set."));
    }

    [Test]
    public void LocationSet()
    {
        _holdout.holdoutStorageLocation = _holdoutDir.FullName;

        var ex = Assert.Throws<Exception>(() => _holdout.Check(ThrowImmediatelyCheckNotifier.Quiet));
        Assert.That(ex.Message, Does.Contain("No data holdout count set."));
    }

    [Test]
    public void LocationAndCountSet()
    {
        _holdout.holdoutStorageLocation = _holdoutDir.FullName;
        _holdout.holdoutCount = 1;

        _holdout.Check(ThrowImmediatelyCheckNotifier.Quiet);

        // checking the configuration must not write anything
        Assert.That(HoldoutFilePath, Does.Not.Exist);
    }

    [Test]
    public void ExtractionHoldoutSingle()
    {
        _holdout.holdoutStorageLocation = _holdoutDir.FullName;
        _holdout.holdoutCount = 1;

        RunHoldout(CreateExtractCommand());

        AssertHoldoutFileStartsWith(1);
    }

    [Test]
    public void ExtractionHoldoutPercentage()
    {
        _holdout.holdoutStorageLocation = _holdoutDir.FullName;
        _holdout.holdoutCount = 33;
        _holdout.isPercentage = true;

        RunHoldout(CreateExtractCommand());

        // 33% of 6 rows is 1.98, which the component rounds up to 2
        AssertHoldoutFileStartsWith(2);
    }

    [Test]
    public void ExtractionHoldoutPercentageAppend()
    {
        _holdout.holdoutStorageLocation = _holdoutDir.FullName;
        _holdout.holdoutCount = 33;
        _holdout.isPercentage = true;
        _holdout.overrideFile = false;
        var command = CreateExtractCommand();

        RunHoldout(command);

        // 33% of 6 rows is 1.98, which the component rounds up to 2
        AssertHoldoutFileStartsWith(2);

        RunHoldout(command);

        // the first run removed 2 rows; 33% of the remaining 4 is 1.32, rounded up to 2 more,
        // appended beneath the existing header and rows
        AssertHoldoutFileStartsWith(4);
    }
}
243 changes: 243 additions & 0 deletions Rdmp.Core/DataExport/DataExtraction/Pipeline/ExtractionHoldout.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// Copyright (c) The University of Dundee 2018-2023
// This file is part of the Research Data Management Platform (RDMP).
// RDMP is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
// RDMP is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
// You should have received a copy of the GNU General Public License along with RDMP. If not, see <https://www.gnu.org/licenses/>.

using Rdmp.Core.Curation.Data;
using Rdmp.Core.DataExport.DataExtraction.Commands;
using Rdmp.Core.DataFlowPipeline;
using Rdmp.Core.DataFlowPipeline.Requirements;
using Rdmp.Core.ReusableLibraryCode.Checks;
using Rdmp.Core.ReusableLibraryCode.Progress;
using System;
using System.Collections.Generic;
using System.Data;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;

namespace Rdmp.Core.DataExport.DataExtraction.Pipeline;

/// <summary>
/// Pipeline component that removes a random selection of rows (the "holdout" data) from each
/// <see cref="DataTable"/> passing through it and, when <see cref="holdoutStorageLocation"/> is set,
/// writes the removed rows to a CSV file named <c>holdout_{Request}.csv</c> in that directory.
/// Rows can be restricted to a date window and/or a row filter expression before selection.
/// </summary>
public class ExtractionHoldout : IPluginDataFlowComponent<DataTable>, IPipelineRequirement<IExtractCommand>
{
    [DemandsInitialization("The number of rows you want to be kept as holdout data. If isPercentage is set this is treated as a % of the data instead")]
    public int holdoutCount { get; set; }

    [DemandsInitialization("Use a % as holdout value. If unselected, the actual number will be used.")]
    public bool isPercentage { get; set; }

    [DemandsInitialization("Write the holdout data to disk. Leave blank if you don't want it exported somewhere")]
    public string holdoutStorageLocation { get; set; }

    [DemandsInitialization("Set this value to only select data for holdout that is before this date")]
    public DateTime beforeDate { get; set; }

    [DemandsInitialization("Set this value to only select data for holdout that is after this date")]
    public DateTime afterDate { get; set; }

    [DemandsInitialization("The column that the before and after date options use to filter holdout data")]
    public string dateColumn { get; set; }

    //can only filter on strings, not dates
    [DemandsInitialization("Allows for the filtering of what data can be used as holdout data. The filter only currently supports filtering on string columns, not dates. Filter References https://learn.microsoft.com/en-us/dotnet/api/system.data.dataview.rowfilter?view=net-7.0 and https://learn.microsoft.com/en-us/dotnet/api/system.data.datacolumn.expression?view=net-7.0")]
    public string whereCondition { get; set; }

    [DemandsInitialization("Overrides any data in the holdout file with new data")]
    public bool overrideFile { get; set; }

    // We may want to automatically reimport into RDMP, but this is quite complicated.
    // It may be worth having users reimport the catalogue themselves until it is proven that this is worth building.
    // Currently only writing holdout data to a CSV is supported.

    // Characters that force a CSV field to be quoted.
    private static readonly char[] CsvSpecialCharacters = { ',', '"', '\r', '\n' };

    /// <summary>
    /// The dataset extraction command supplied in <see cref="PreInitialize"/>, or null when the
    /// supplied command was not an <see cref="IExtractDatasetCommand"/>.
    /// </summary>
    public IExtractDatasetCommand Request { get; private set; }

    /// <summary>True when a date column and at least one of the date bounds have been configured.</summary>
    private bool HasDateFilter =>
        !string.IsNullOrWhiteSpace(dateColumn) &&
        (afterDate != DateTime.MinValue || beforeDate != DateTime.MinValue);

    /// <summary>
    /// Decides whether <paramref name="row"/> falls inside the configured date window: strictly
    /// after <see cref="afterDate"/> and strictly before <see cref="beforeDate"/> (a bound left at
    /// <see cref="DateTime.MinValue"/> is ignored).
    /// </summary>
    /// <returns>False when the date cell is null or outside the window, otherwise true.</returns>
    /// <exception cref="FormatException">The cell is not a DateTime and cannot be parsed as one.</exception>
    private bool IsWithinDateWindow(DataRow row)
    {
        object raw = row[dateColumn];

        // a row without a date cannot be shown to satisfy a date restriction
        if (raw is null or DBNull)
        {
            return false;
        }

        // typed DateTime columns are used directly; anything else is parsed from its text form
        DateTime dateCell = raw is DateTime typed
            ? typed
            : DateTime.Parse(Convert.ToString(raw, CultureInfo.InvariantCulture), CultureInfo.InvariantCulture);

        if (afterDate != DateTime.MinValue && dateCell <= afterDate)
        {
            return false;
        }

        if (beforeDate != DateTime.MinValue && dateCell >= beforeDate)
        {
            return false;
        }

        return true;
    }

    /// <summary>
    /// Returns the rows of <paramref name="toProcess"/> that may be chosen as holdout data, i.e.
    /// those matching <see cref="whereCondition"/> (if set) and the date window (if set).
    /// </summary>
    private IEnumerable<DataRow> GetEligibleRows(DataTable toProcess)
    {
        // DataTable.Select evaluates the filter expression once for the whole table instead of
        // cloning the table and building a DataView for every single row
        IEnumerable<DataRow> eligible = string.IsNullOrWhiteSpace(whereCondition)
            ? toProcess.AsEnumerable()
            : toProcess.Select(whereCondition);

        return HasDateFilter ? eligible.Where(IsWithinDateWindow) : eligible;
    }

    /// <summary>
    /// Works out how many rows to hold back from <paramref name="toProcess"/>. The figure is based
    /// on the total number of rows in the table (not only the rows that pass the filters) and is
    /// rounded up when a percentage is used. Warnings are raised on <paramref name="listener"/>
    /// when the request exceeds the rows available or 100%.
    /// </summary>
    private int GetHoldoutRowCount(DataTable toProcess, IDataLoadEventListener listener)
    {
        int available = toProcess.Rows.Count;

        if (!isPercentage)
        {
            if (holdoutCount > available)
            {
                listener.OnNotify(this, new NotifyEventArgs(ProgressEventType.Warning, "More holdout data was requested than there is available data. All valid data will be held back"));
                return available;
            }

            return Math.Max(holdoutCount, 0);
        }

        // clamp into a local so the user's configured value is not overwritten
        int percentage = holdoutCount;
        if (percentage > 100)
        {
            listener.OnNotify(this, new NotifyEventArgs(ProgressEventType.Warning, "Holdout percentage was >100%. Will use 100%"));
            percentage = 100;
        }

        if (percentage <= 0)
        {
            return 0;
        }

        // integer ceiling of available * percentage / 100; long avoids overflow on large tables
        return (int)(((long)available * percentage + 99) / 100);
    }

    /// <summary>
    /// Records the extraction command. <see cref="Request"/> is left null when
    /// <paramref name="request"/> is not an <see cref="IExtractDatasetCommand"/>.
    /// </summary>
    public void PreInitialize(IExtractCommand request, IDataLoadEventListener listener)
    {
        Request = request as IExtractDatasetCommand;
    }

    /// <summary>
    /// Quotes <paramref name="value"/> for use as a CSV field when it contains a comma, double
    /// quote or line break, doubling any embedded double quotes. Other values are returned as is.
    /// </summary>
    private static string EscapeCsvField(string value)
    {
        if (string.IsNullOrEmpty(value))
        {
            return string.Empty;
        }

        if (value.IndexOfAny(CsvSpecialCharacters) < 0)
        {
            return value;
        }

        return "\"" + value.Replace("\"", "\"\"") + "\"";
    }

    /// <summary>
    /// Writes <paramref name="dt"/> to <c>holdout_{Request}.csv</c> in
    /// <see cref="holdoutStorageLocation"/>. When the file already exists and
    /// <see cref="overrideFile"/> is false the rows are appended without a header; otherwise the
    /// file is (re)created with a header line.
    /// </summary>
    /// <exception cref="InvalidOperationException">No dataset extraction command is available to name the file.</exception>
    private void WriteDataTabletoCSV(DataTable dt)
    {
        if (Request is null)
        {
            throw new InvalidOperationException("Cannot write holdout data because no dataset extraction command was supplied to PreInitialize.");
        }

        // Path.Combine copes with a trailing separator on the configured directory
        string path = Path.Combine(holdoutStorageLocation, $"holdout_{Request}.csv");
        bool append = !overrideFile && File.Exists(path);

        StringBuilder sb = new();
        if (!append)
        {
            IEnumerable<string> columnNames = dt.Columns.Cast<DataColumn>().Select(column => EscapeCsvField(column.ColumnName));
            sb.AppendLine(string.Join(",", columnNames));
        }

        foreach (DataRow row in dt.Rows)
        {
            IEnumerable<string> fields = row.ItemArray.Select(field => EscapeCsvField(field?.ToString()));
            sb.AppendLine(string.Join(",", fields));
        }

        if (append)
        {
            // every line already ends with a newline, so append verbatim (no extra blank line)
            File.AppendAllText(path, sb.ToString());
            return;
        }

        // NOTE(review): with overrideFile set the file is replaced on every call, so if one run
        // delivers several batches only the last batch's holdout rows survive - confirm intended.
        File.WriteAllText(path, sb.ToString());
    }

    /// <summary>
    /// Moves a random selection of eligible rows out of <paramref name="toProcess"/> and, when a
    /// storage location is configured, writes them to the holdout CSV file.
    /// </summary>
    /// <returns><paramref name="toProcess"/> with the holdout rows removed.</returns>
    public DataTable ProcessPipelineData(DataTable toProcess, IDataLoadEventListener listener, GracefulCancellationToken cancellationToken)
    {
        if (toProcess is null || toProcess.Rows.Count == 0)
        {
            return toProcess;
        }

        int foundHoldoutCount = GetHoldoutRowCount(toProcess, listener);
        var rand = new Random();

        // materialise once: the selection must be fixed before rows start being removed from the
        // table, and a deferred query would be re-shuffled every time it is enumerated
        List<DataRow> rowsToMove = GetEligibleRows(toProcess)
            .OrderBy(_ => rand.Next())
            .Take(foundHoldoutCount)
            .ToList();

        if (rowsToMove.Count == 0)
        {
            listener.OnNotify(this, new NotifyEventArgs(ProgressEventType.Warning, "No valid holdout rows were found. Please check your settings."));
        }

        using DataTable holdoutData = toProcess.Clone();
        holdoutData.BeginLoadData();
        toProcess.BeginLoadData();
        try
        {
            foreach (DataRow row in rowsToMove)
            {
                holdoutData.ImportRow(row);
                toProcess.Rows.Remove(row);
            }
        }
        finally
        {
            holdoutData.EndLoadData();
            toProcess.EndLoadData();
        }

        if (!string.IsNullOrWhiteSpace(holdoutStorageLocation))
        {
            WriteDataTabletoCSV(holdoutData);
        }

        return toProcess;
    }

    /// <summary>
    /// Reports a failure to <paramref name="notifier"/> when no storage location is set, and when
    /// the holdout count is zero or negative.
    /// </summary>
    public void Check(ICheckNotifier notifier)
    {
        if (string.IsNullOrWhiteSpace(holdoutStorageLocation))
        {
            notifier.OnCheckPerformed(new CheckEventArgs("No holdout file location set.", CheckResult.Fail));
        }

        if (holdoutCount is 0)
        {
            notifier.OnCheckPerformed(new CheckEventArgs("No data holdout count set.", CheckResult.Fail));
        }
        else if (holdoutCount < 0)
        {
            // a negative count would silently hold back nothing
            notifier.OnCheckPerformed(new CheckEventArgs("Data holdout count cannot be negative.", CheckResult.Fail));
        }
    }

    /// <summary>Nothing to abort: the component keeps no open resources between calls.</summary>
    public void Abort(IDataLoadEventListener listener)
    {
    }

    /// <summary>Nothing to dispose: files are opened and closed within each write.</summary>
    public void Dispose(IDataLoadEventListener listener, Exception pipelineFailureExceptionIfAny)
    {
    }
}