diff --git a/python/test/train_test_model_test.py b/python/test/train_test_model_test.py
index cea491d32..cc6e6ce82 100644
--- a/python/test/train_test_model_test.py
+++ b/python/test/train_test_model_test.py
@@ -6,7 +6,8 @@
 from vmaf.config import VmafConfig
 from vmaf.core.train_test_model import TrainTestModel, \
     LibsvmNusvrTrainTestModel, SklearnRandomForestTrainTestModel, \
-    MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, SklearnLinearRegressionTrainTestModel
+    MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, \
+    SklearnLinearRegressionTrainTestModel, Logistic5PLRegressionTrainTestModel
 from vmaf.core.noref_feature_extractor import MomentNorefFeatureExtractor
 from vmaf.routine import read_dataset
 from vmaf.tools.misc import import_python_file
@@ -309,6 +310,20 @@ def test_train_predict_extratrees(self):
         result = model.evaluate(xs, ys)
         self.assertAlmostEqual(result['RMSE'], 0.042867322777879642, places=4)
 
+    def test_train_logistic_fit_5PL(self):
+        xs = Logistic5PLRegressionTrainTestModel.get_xs_from_results(self.features, [0, 1, 2, 3, 4, 5], features=['Moment_noref_feature_1st_score'])
+        ys = Logistic5PLRegressionTrainTestModel.get_ys_from_results(self.features, [0, 1, 2, 3, 4, 5])
+
+        xys = {}
+        xys.update(xs)
+        xys.update(ys)
+
+        model = Logistic5PLRegressionTrainTestModel({'norm_type': 'clip_0to1'}, None)
+        model.train(xys)
+        result = model.evaluate(xs, ys)
+
+        self.assertAlmostEqual(result['RMSE'], 0.3603374311919728, places=4)
+
 
 class TrainTestModelWithDisYRawVideoExtractorTest(unittest.TestCase):
 
diff --git a/python/vmaf/core/train_test_model.py b/python/vmaf/core/train_test_model.py
index 356fc2c5e..4b8332216 100644
--- a/python/vmaf/core/train_test_model.py
+++ b/python/vmaf/core/train_test_model.py
@@ -740,7 +740,7 @@ def _delete(filename, **more):
             os.remove(filename)
 
     @classmethod
-    def get_xs_from_results(cls, results, indexs=None, aggregate=True):
+    def get_xs_from_results(cls, results, indexs=None, aggregate=True, features=None):
         """
         :param results: list of BasicResult, or pandas.DataFrame
         :param indexs: indices of results to be used
@@ -756,8 +756,11 @@ def get_xs_from_results(cls, results, indexs=None, aggregate=True):
             # or get_ordered_list_scores_key. Instead, just get the sorted keys
             feature_names = results[0].get_ordered_results()
 
-        feature_names = list(feature_names)
-        cls._assert_dimension(feature_names, results)
+        if features is not None:
+            feature_names = [f for f in feature_names if f in features]
+        else:
+            feature_names = list(feature_names)
+        cls._assert_dimension(feature_names, results)
 
         # collect results into xs
         xs = {}
@@ -1189,10 +1192,9 @@ def _train(cls, model_param, xys_2d, **kwargs):
         del model_param_['num_models']
 
         from scipy.optimize import curve_fit
-
         [[b1, b2, b3, b4, b5], _] = curve_fit(
             lambda x, b1, b2, b3, b4, b5: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5,
-            np.ravel(xys_2d[:, 1:]),
+            np.ravel(xys_2d[:, 1]),
             np.ravel(xys_2d[:, 0]),
             p0=0.5 * np.ones((5,)),
             maxfev=20000
diff --git a/unittest b/unittest
index 6ca06c762..fc36d56d1 100755
--- a/unittest
+++ b/unittest
@@ -1,3 +1,9 @@
 #!/usr/bin/env sh
 
-PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p '*_test.py'
+if [ -z "$1" ]; then
+    pattern='*_test.py'
+else
+    pattern="$1"
+fi
+
+PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p $pattern
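
For context on the `np.ravel(xys_2d[:, 1:])` to `np.ravel(xys_2d[:, 1])` change: with a single selected feature column, `curve_fit` should receive a 1-D x array aligned with the 1-D y array. Below is a minimal standalone sketch of the same functional form fitted in `Logistic5PLRegressionTrainTestModel._train`; the synthetic data, parameter values, and `fit_func` name are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of the curve fit used in Logistic5PLRegressionTrainTestModel._train,
# run on synthetic data (illustrative only, not from the VMAF test suite).
import numpy as np
from scipy.optimize import curve_fit

def fit_func(x, b1, b2, b3, b4, b5):
    # same functional form as the lambda in the patched _train()
    return b1 + (0.5 - 1 / (1 + np.exp(b2 * (x - b3)))) + b4 * x + b5

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 50)   # one feature column, passed as a 1-D array
y = fit_func(x, 0.1, 8.0, 0.5, 0.3, 0.2) + 0.01 * rng.standard_normal(x.shape)

[b1, b2, b3, b4, b5], _ = curve_fit(
    fit_func, x, y, p0=0.5 * np.ones((5,)), maxfev=20000)
print(b1, b2, b3, b4, b5)
```

The new test can be run on its own via the updated `unittest` wrapper, which now accepts a filename pattern as its first argument, e.g. `./unittest 'train_test_model_test.py'`; with no argument it falls back to the previous behavior of discovering all `*_test.py` files.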