From 7fd670f3af2d418ed21cd26313ca8dee18514522 Mon Sep 17 00:00:00 2001 From: Najeeb Kazmi Date: Fri, 27 Sep 2019 20:33:30 -0700 Subject: [PATCH] Add entrypoint for PFI (#4232) * Add entrypoint for PFI * Regenerate EP catalog * Add tests * Adding Standard Error of Mean to PFI Metrics in EntryPoint * PR Feedback * Remove MLContext from EntryPoint * Use last predictor in the model if model.LastTransformer is not a predictor * nit * Model file path conflicts in tests * nit * PR Feedback * Pass in model as PredictorModel * Remove label column and group ID column from entrypoint input arguments * nit * Simplify data prep --- .../PermutationFeatureImportance.cs | 429 ++++++++++ .../PermutationFeatureImportanceExtensions.cs | 4 +- .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 73 ++ .../UnitTests/TestEntryPoints.cs | 737 +++++++++++++++++- 5 files changed, 1241 insertions(+), 3 deletions(-) create mode 100644 src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs new file mode 100644 index 0000000000..135ad3b333 --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -0,0 +1,429 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; + +[assembly: LoadableClass(typeof(void), typeof(PermutationFeatureImportanceEntryPoints), null, typeof(SignatureEntryPointModule), "PermutationFeatureImportance")] + +namespace Microsoft.ML.Transforms +{ + internal static class PermutationFeatureImportanceEntryPoints + { + [TlcModule.EntryPoint(Name = "Transforms.PermutationFeatureImportance", Desc = "Permutation Feature Importance (PFI)", UserName = "PFI", ShortName = "PFI")] + public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IHostEnvironment env, PermutationFeatureImportanceArguments input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("Pfi"); + host.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + + input.PredictorModel.PrepareData(env, input.Data, out RoleMappedData roleMappedData, out IPredictor predictor); + Contracts.Assert(predictor != null, "No predictor found in model"); + IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData, input); + return new PermutationFeatureImportanceOutput { Metrics = result }; + } + } + + internal sealed class PermutationFeatureImportanceOutput + { + [TlcModule.Output(Desc = "The PFI metrics")] + public IDataView Metrics; + } + + internal sealed class PermutationFeatureImportanceArguments : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "The path to the model file", ShortName = "path", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public PredictorModel PredictorModel; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Use feature weights to pre-filter features", ShortName = "usefw", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public bool UseFeatureWeightFilter = false; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of examples to evaluate on", ShortName = "numexamples", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public int? NumberOfExamplesToUse = null; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of permutations to perform", ShortName = "permutations", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public int PermutationCount = 1; + } + + internal static class PermutationFeatureImportanceUtils + { + internal static IDataView GetMetrics( + IHostEnvironment env, + IPredictor predictor, + RoleMappedData roleMappedData, + PermutationFeatureImportanceArguments input) + { + IDataView result; + if (predictor.PredictionKind == PredictionKind.BinaryClassification) + result = GetBinaryMetrics(env, predictor, roleMappedData, input); + else if (predictor.PredictionKind == PredictionKind.MulticlassClassification) + result = GetMulticlassMetrics(env, predictor, roleMappedData, input); + else if (predictor.PredictionKind == PredictionKind.Regression) + result = GetRegressionMetrics(env, predictor, roleMappedData, input); + else if (predictor.PredictionKind == PredictionKind.Ranking) + result = GetRankingMetrics(env, predictor, roleMappedData, input); + else + throw Contracts.Except( + "Unsupported predictor type. Predictor must be binary classifier, " + + "multiclass classifier, regressor, or ranker."); + + return result; + } + + private static IDataView GetBinaryMetrics( + IHostEnvironment env, + IPredictor predictor, + RoleMappedData roleMappedData, + PermutationFeatureImportanceArguments input) + { + var roles = roleMappedData.Schema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; + var pred = new BinaryPredictionTransformer>( + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); + var binaryCatalog = new BinaryClassificationCatalog(env); + var permutationMetrics = binaryCatalog + .PermutationFeatureImportance(pred, + roleMappedData.Data, + labelColumnName: labelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + var slotNames = GetSlotNames(roleMappedData.Schema); + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + List metrics = new List(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; + var pMetric = permutationMetrics[i]; + metrics.Add(new BinaryMetrics + { + FeatureName = slotNames[i], + AreaUnderRocCurve = pMetric.AreaUnderRocCurve.Mean, + AreaUnderRocCurveStdErr = pMetric.AreaUnderRocCurve.StandardError, + Accuracy = pMetric.Accuracy.Mean, + AccuracyStdErr = pMetric.Accuracy.StandardError, + PositivePrecision = pMetric.PositivePrecision.Mean, + PositivePrecisionStdErr = pMetric.PositivePrecision.StandardError, + PositiveRecall = pMetric.PositiveRecall.Mean, + PositiveRecallStdErr = pMetric.PositiveRecall.StandardError, + NegativePrecision = pMetric.NegativePrecision.Mean, + NegativePrecisionStdErr = pMetric.NegativePrecision.StandardError, + NegativeRecall = pMetric.NegativeRecall.Mean, + NegativeRecallStdErr = pMetric.NegativeRecall.StandardError, + F1Score = pMetric.F1Score.Mean, + F1ScoreStdErr = pMetric.F1Score.StandardError, + AreaUnderPrecisionRecallCurve = pMetric.AreaUnderPrecisionRecallCurve.Mean, + AreaUnderPrecisionRecallCurveStdErr = pMetric.AreaUnderPrecisionRecallCurve.StandardError + }); + } + + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); + return result; + } + + private static IDataView GetMulticlassMetrics( + IHostEnvironment env, + IPredictor predictor, + RoleMappedData roleMappedData, + PermutationFeatureImportanceArguments input) + { + var roles = roleMappedData.Schema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; + var pred = new MulticlassPredictionTransformer>>( + env, predictor as IPredictorProducing>, roleMappedData.Data.Schema, featureColumnName, labelColumnName); + var multiclassCatalog = new MulticlassClassificationCatalog(env); + var permutationMetrics = multiclassCatalog + .PermutationFeatureImportance(pred, + roleMappedData.Data, + labelColumnName: labelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + var slotNames = GetSlotNames(roleMappedData.Schema); + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + List metrics = new List(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; + var pMetric = permutationMetrics[i]; + metrics.Add(new MulticlassMetrics + { + FeatureName = slotNames[i], + MacroAccuracy = pMetric.MacroAccuracy.Mean, + MacroAccuracyStdErr = pMetric.MacroAccuracy.StandardError, + MicroAccuracy = pMetric.MicroAccuracy.Mean, + MicroAccuracyStdErr = pMetric.MicroAccuracy.StandardError, + LogLoss = pMetric.LogLoss.Mean, + LogLossStdErr = pMetric.LogLoss.StandardError, + LogLossReduction = pMetric.LogLossReduction.Mean, + LogLossReductionStdErr = pMetric.LogLossReduction.StandardError, + TopKAccuracy = pMetric.TopKAccuracy.Mean, + TopKAccuracyStdErr = pMetric.TopKAccuracy.StandardError, + PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(), + PerClassLogLossStdErr = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray() + }); ; + } + + // Convert unknown size vectors to known size. + var metric = metrics.First(); + SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics)); + ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema); + + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics, schema); + return result; + } + + private static IDataView GetRegressionMetrics( + IHostEnvironment env, + IPredictor predictor, + RoleMappedData roleMappedData, + PermutationFeatureImportanceArguments input) + { + var roles = roleMappedData.Schema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; + var pred = new RegressionPredictionTransformer>( + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); + var regressionCatalog = new RegressionCatalog(env); + var permutationMetrics = regressionCatalog + .PermutationFeatureImportance(pred, + roleMappedData.Data, + labelColumnName: labelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + var slotNames = GetSlotNames(roleMappedData.Schema); + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + List metrics = new List(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; + var pMetric = permutationMetrics[i]; + metrics.Add(new RegressionMetrics + { + FeatureName = slotNames[i], + MeanAbsoluteError = pMetric.MeanAbsoluteError.Mean, + MeanAbsoluteErrorStdErr = pMetric.MeanAbsoluteError.StandardError, + MeanSquaredError = pMetric.MeanSquaredError.Mean, + MeanSquaredErrorStdErr = pMetric.MeanSquaredError.StandardError, + RootMeanSquaredError = pMetric.RootMeanSquaredError.Mean, + RootMeanSquaredErrorStdErr = pMetric.RootMeanSquaredError.StandardError, + LossFunction = pMetric.LossFunction.Mean, + LossFunctionStdErr = pMetric.LossFunction.StandardError, + RSquared = pMetric.RSquared.Mean, + RSquaredStdErr = pMetric.RSquared.StandardError + }); + } + + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); + return result; + } + + private static IDataView GetRankingMetrics( + IHostEnvironment env, + IPredictor predictor, + RoleMappedData roleMappedData, + PermutationFeatureImportanceArguments input) + { + var roles = roleMappedData.Schema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; + var groupIdColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Group.Value).First().Value; + var pred = new RankingPredictionTransformer>( + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); + var rankingCatalog = new RankingCatalog(env); + var permutationMetrics = rankingCatalog + .PermutationFeatureImportance(pred, + roleMappedData.Data, + labelColumnName: labelColumnName, + rowGroupColumnName: groupIdColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + var slotNames = GetSlotNames(roleMappedData.Schema); + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + List metrics = new List(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; + var pMetric = permutationMetrics[i]; + metrics.Add(new RankingMetrics + { + FeatureName = slotNames[i], + DiscountedCumulativeGains = pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(), + DiscountedCumulativeGainsStdErr = pMetric.DiscountedCumulativeGains.Select(x => x.StandardError).ToArray(), + NormalizedDiscountedCumulativeGains = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray(), + NormalizedDiscountedCumulativeGainsStdErr = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.StandardError).ToArray() + }); + } + + // Convert unknown size vectors to known size. + var metric = metrics.First(); + SchemaDefinition schema = SchemaDefinition.Create(typeof(RankingMetrics)); + ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGains), metric.DiscountedCumulativeGains.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGains), metric.NormalizedDiscountedCumulativeGains.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGainsStdErr), metric.DiscountedCumulativeGainsStdErr.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr), metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema); + + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics, schema); + return result; + } + + private static string[] GetSlotNames(RoleMappedSchema schema) + { + VBuffer> slots = default; + schema.Feature.Value.GetSlotNames(ref slots); + var slotValues = slots.DenseValues(); + + List slotNames = new List(); + foreach (var value in slotValues) + { + slotNames.Add(value.ToString()); + }; + + return slotNames.ToArray(); + } + + private static void ConvertVectorToKnownSize(string metricName, int size, ref SchemaDefinition schema) + { + var type = ((VectorDataViewType)schema[metricName].ColumnType).ItemType; + schema[metricName].ColumnType = new VectorDataViewType(type, size); + } + + private class BinaryMetrics + { + public string FeatureName { get; set; } + + public double AreaUnderRocCurve { get; set; } + + public double AreaUnderRocCurveStdErr { get; set; } + + public double Accuracy { get; set; } + + public double AccuracyStdErr { get; set; } + + public double PositivePrecision { get; set; } + + public double PositivePrecisionStdErr { get; set; } + + public double PositiveRecall { get; set; } + + public double PositiveRecallStdErr { get; set; } + + public double NegativePrecision { get; set; } + + public double NegativePrecisionStdErr { get; set; } + + public double NegativeRecall { get; set; } + + public double NegativeRecallStdErr { get; set; } + + public double F1Score { get; set; } + + public double F1ScoreStdErr { get; set; } + + public double AreaUnderPrecisionRecallCurve { get; set; } + + public double AreaUnderPrecisionRecallCurveStdErr { get; set; } + } + + private class MulticlassMetrics + { + public string FeatureName { get; set; } + + public double MacroAccuracy { get; set; } + + public double MacroAccuracyStdErr { get; set; } + + public double MicroAccuracy { get; set; } + + public double MicroAccuracyStdErr { get; set; } + + public double LogLoss { get; set; } + + public double LogLossStdErr { get; set; } + + public double LogLossReduction { get; set; } + + public double LogLossReductionStdErr { get; set; } + + public double TopKAccuracy { get; set; } + + public double TopKAccuracyStdErr { get; set; } + + public double[] PerClassLogLoss { get; set; } + + public double[] PerClassLogLossStdErr { get; set; } + } + + private class RegressionMetrics + { + public string FeatureName { get; set; } + + public double MeanAbsoluteError { get; set; } + + public double MeanAbsoluteErrorStdErr { get; set; } + + public double MeanSquaredError { get; set; } + + public double MeanSquaredErrorStdErr { get; set; } + + public double RootMeanSquaredError { get; set; } + + public double RootMeanSquaredErrorStdErr { get; set; } + + public double LossFunction { get; set; } + + public double LossFunctionStdErr { get; set; } + + public double RSquared { get; set; } + + public double RSquaredStdErr { get; set; } + } + + private class RankingMetrics + { + public string FeatureName { get; set; } + + public double[] DiscountedCumulativeGains { get; set; } + + public double[] DiscountedCumulativeGainsStdErr { get; set; } + + public double[] NormalizedDiscountedCumulativeGains { get; set; } + + public double[] NormalizedDiscountedCumulativeGainsStdErr { get; set; } + } + } +} diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index a860a3023b..07b9c8f435 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -206,7 +206,7 @@ private static BinaryClassificationMetrics BinaryClassifierDelta( /// ]]> /// /// - /// The clustering catalog. + /// The multiclass classification catalog. /// The model on which to evaluate feature importance. /// The evaluation data set. /// Label column name. The column data must be . @@ -291,7 +291,7 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta( /// ]]> /// /// - /// The clustering catalog. + /// The ranking catalog. /// The model on which to evaluate feature importance. /// The evaluation data set. /// Label column name. The column data must be or . diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 6288a254ae..4f2bcc426a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -116,6 +116,7 @@ Transforms.NGramTranslator Produces a bag of counts of n-grams (sequences of con Transforms.NoOperation Does nothing. Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Transforms.OptionalColumnTransform MakeOptional Microsoft.ML.Transforms.OptionalColumnTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer Calculate Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.PermutationFeatureImportance Permutation Feature Importance (PFI) Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Transforms.RandomNumberGenerator Generate Microsoft.ML.Transforms.GenerateNumberTransform+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.EntryPoints.SelectRows FilterByRange Microsoft.ML.Transforms.RangeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index a68890fe6f..c8e6d6e55c 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21710,6 +21710,79 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.PermutationFeatureImportance", + "Desc": "Permutation Feature Importance (PFI)", + "FriendlyName": "PFI", + "ShortName": "PFI", + "Inputs": [ + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "PredictorModel", + "Type": "PredictorModel", + "Desc": "The path to the model file", + "Aliases": [ + "path" + ], + "Required": true, + "SortOrder": 150.0, + "IsNullable": false + }, + { + "Name": "UseFeatureWeightFilter", + "Type": "Bool", + "Desc": "Use feature weights to pre-filter features", + "Aliases": [ + "usefw" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "NumberOfExamplesToUse", + "Type": "Int", + "Desc": "Limit the number of examples to evaluate on", + "Aliases": [ + "numexamples" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "PermutationCount", + "Type": "Int", + "Desc": "The number of permutations to perform", + "Aliases": [ + "permutations" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 1 + } + ], + "Outputs": [ + { + "Name": "Metrics", + "Type": "DataView", + "Desc": "The PFI metrics" + } + ], + "InputKind": [ + "ITransformInput" + ] + }, { "Name": "Transforms.PredictedLabelColumnOriginalValueConverter", "Desc": "Transforms a predicted label column to its original values, unless it is of type bool.", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index e9727d13a7..87a0c61f1c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -612,11 +612,746 @@ public void EntryPointExecGraphCommand() cmd.Run(); } + [Fact] + public void BinaryPermutationFeatureImportance() + { + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); + var outputDataPath = DeleteOutputPath("binary_pfi_metrics.idv"); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data' + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'MaximumNumberOfIterations': 1, + 'NumThreads': 1, + 'TrainingData': '$output_data4' + }}, + 'Name': 'Trainers.LogisticRegressionBinaryClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var mlContext = new MLContext(); + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderRocCurve")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("Accuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositivePrecision")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositiveRecall")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativePrecision")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativeRecall")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("F1Score")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderPrecisionRecallCurve")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderRocCurveStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositivePrecisionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositiveRecallStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativePrecisionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativeRecallStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("F1ScoreStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderPrecisionRecallCurveStdErr")); + } + + [Fact] + public void MulticlassPermutationFeatureImportance() + { + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); + var outputDataPath = DeleteOutputPath("mc_pfi_metrics.idv"); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var mlContext = new MLContext(); + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReduction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReductionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLossStdErr")); + } + + [Fact] + public void MulticlassPermutationFeatureImportanceWithKeyToValue() + { + var inputData = GetDataPath("adult.tiny.with-schema.txt"); + var outputDataPath = DeleteOutputPath("mc_ktv_pfi_metrics.idv"); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data4', + 'Column': [ + {{ + 'Name': 'Label', + 'Source': 'Label' + }} + ], + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model5', + 'OutputData': '$output_data5' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4', + '$output_model5' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputData), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var mlContext = new MLContext(); + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReduction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReductionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLossStdErr")); + } + + [Fact] + public void RegressionPermutationFeatureImportance() + { + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); + var outputDataPath = DeleteOutputPath("reg_pfi_metrics.idv"); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label' + }}, + 'Name': 'Transforms.LabelToFloatConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentRegressor', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var mlContext = new MLContext(); + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanAbsoluteError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanSquaredError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RootMeanSquaredError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LossFunction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RSquared")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanAbsoluteErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanSquaredErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RootMeanSquaredErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LossFunctionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RSquaredStdErr")); + } + + [Fact] + public void RankingPermutationFeatureImportance() + { + var inputData = GetDataPath("adult.tiny.with-schema.txt"); + var outputDataPath = DeleteOutputPath("rank_pfi_metrics.idv"); + + string inputGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Name': 'Data.CustomTextLoader', + 'Outputs': {{ + 'Data': '$input_data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'Workclass', + 'Source': 'Workclass' + }} + ], + 'Data': '$input_data', + 'MaxNumTerms': 1000000, + 'Sort': 'ByOccurrence', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education-num', + 'capital-gain' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'NumberOfTrees': 1, + 'RowGroupColumnName': 'Workclass', + 'TrainingData': '$output_data4', + 'NumberOfLeaves': 2 + }}, + 'Name': 'Trainers.FastTreeRanker', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$input_data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputData), EscapePath(outputDataPath)); + + var jsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + + var mlContext = new MLContext(); + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGains")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGains")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGainsStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGainsStdErr")); + } + [Fact] public void ScoreTransformerChainModel() { var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); - var modelPath = DeleteOutputPath("model.zip"); + var modelPath = DeleteOutputPath("score_model.zip"); var outputDataPath = DeleteOutputPath("scored.idv"); var mlContext = new MLContext();