Skip to content

Commit

Permalink
add unit test for 5PL fit model
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmin committed Feb 1, 2021
1 parent 243f60f commit f03ebd7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
17 changes: 16 additions & 1 deletion python/test/train_test_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from vmaf.config import VmafConfig
from vmaf.core.train_test_model import TrainTestModel, \
LibsvmNusvrTrainTestModel, SklearnRandomForestTrainTestModel, \
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, SklearnLinearRegressionTrainTestModel
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, \
SklearnLinearRegressionTrainTestModel, Logistic5PLRegressionTrainTestModel
from vmaf.core.noref_feature_extractor import MomentNorefFeatureExtractor
from vmaf.routine import read_dataset
from vmaf.tools.misc import import_python_file
Expand Down Expand Up @@ -309,6 +310,20 @@ def test_train_predict_extratrees(self):
result = model.evaluate(xs, ys)
self.assertAlmostEqual(result['RMSE'], 0.042867322777879642, places=4)

def test_train_logistic_fit_5PL(self):
xs = Logistic5PLRegressionTrainTestModel.get_xs_from_results(self.features, [0, 1, 2, 3, 4, 5], features=['Moment_noref_feature_1st_score'])
ys = Logistic5PLRegressionTrainTestModel.get_ys_from_results(self.features, [0, 1, 2, 3, 4, 5])

xys = {}
xys.update(xs)
xys.update(ys)

model = Logistic5PLRegressionTrainTestModel({'norm_type': 'clip_0to1'}, None)
model.train(xys)
result = model.evaluate(xs, ys)

self.assertAlmostEqual(result['RMSE'], 0.3603374311919728, places=4)


class TrainTestModelWithDisYRawVideoExtractorTest(unittest.TestCase):

Expand Down
12 changes: 7 additions & 5 deletions python/vmaf/core/train_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def _delete(filename, **more):
os.remove(filename)

@classmethod
def get_xs_from_results(cls, results, indexs=None, aggregate=True):
def get_xs_from_results(cls, results, indexs=None, aggregate=True, features=None):
"""
:param results: list of BasicResult, or pandas.DataFrame
:param indexs: indices of results to be used
Expand All @@ -756,8 +756,11 @@ def get_xs_from_results(cls, results, indexs=None, aggregate=True):
# or get_ordered_list_scores_key. Instead, just get the sorted keys
feature_names = results[0].get_ordered_results()

feature_names = list(feature_names)
cls._assert_dimension(feature_names, results)
if features is not None:
feature_names = [f for f in feature_names if f in features]
else:
feature_names = list(feature_names)
cls._assert_dimension(feature_names, results)

# collect results into xs
xs = {}
Expand Down Expand Up @@ -1189,10 +1192,9 @@ def _train(cls, model_param, xys_2d, **kwargs):
del model_param_['num_models']

from scipy.optimize import curve_fit

[[b1, b2, b3, b4, b5], _] = curve_fit(
lambda x, b1, b2, b3, b4, b5: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5,
np.ravel(xys_2d[:, 1:]),
np.ravel(xys_2d[:, 1]),
np.ravel(xys_2d[:, 0]),
p0=0.5 * np.ones((5,)),
maxfev=20000
Expand Down
8 changes: 7 additions & 1 deletion unittest
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#!/usr/bin/env sh

PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p '*_test.py'
if [ -z "$1" ]; then
pattern='*_test.py'
else
pattern="$1"
fi

PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p $pattern

0 comments on commit f03ebd7

Please sign in to comment.