forked from Netflix/vmaf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bd_rate_calculator.py
133 lines (99 loc) · 4 KB
/
bd_rate_calculator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
from vmaf.tools.interpolation_utils import InterpolationUtils
__copyright__ = "Copyright 2016-2018, Netflix, Inc."
__license__ = "Apache, Version 2.0"
class BDrateCalculator(object):
    """
    BD rate calculator. Implementation validated against JCTVC-E137:
    http://phenix.it-sudparis.eu/jct/doc_end_user/documents/5_Geneva/wg11/JCTVC-E137-v1.zip

    An RD curve is a sequence of (rate, quality) points, e.g. (bitrate, PSNR).
    CalcBDRate compares a test curve against a reference curve over their
    common quality range.
    """

    # Sentinel codes. NO_OVERLAP is *returned* by CalcBDRate; NON_MONOTONIC and
    # MESSED_UP_RATE are raised as the sole argument of an AssertionError.
    REJECTED_BD_RATE_NO_OVERLAP = -10000
    REJECTED_BD_RATE_NON_MONOTONIC = -20000
    REJECTED_BD_RATE_MESSED_UP_RATE = -30000
    # Not referenced within this class.
    REJECTED_BALANCE_BAD_MEASUREMENTS = -40000

    @staticmethod
    def _dedup_and_order(set_):
        """
        Drop exact-duplicate points and sort the remainder by rate (element 0).

        Each point is converted to a tuple first, so list- or array-shaped
        points (which are unhashable) can be deduplicated as well.

        :param set_: iterable of (rate, quality[, ...]) points.
        :return: list of unique point tuples sorted by ascending rate.
        """
        unique_points = {tuple(point) for point in set_}
        return sorted(unique_points, key=lambda point: point[0])

    @classmethod
    def CalcBDRate(cls, setA, setB):
        """
        Compute the BD-rate of curve setB relative to reference curve setA.

        :param setA: iterable of (rate, quality) points for the reference curve.
        :param setB: iterable of (rate, quality) points for the test curve.
        :return: 10 ** avg - 1, where avg is (bdrint(setB) - bdrint(setA))
            divided by the width of the common quality range; or
            REJECTED_BD_RATE_NO_OVERLAP if the quality ranges do not overlap.
        :raises AssertionError: if, after deduplication, either curve has
            fewer than 4 points (message reports both sizes); if either curve
            is not strictly increasing in both rate and quality (argument
            REJECTED_BD_RATE_NON_MONOTONIC); or if either curve has a rate
            that is not strictly positive (argument
            REJECTED_BD_RATE_MESSED_UP_RATE).
        """
        setA = cls._dedup_and_order(setA)
        setB = cls._dedup_and_order(setB)

        # Explicit raise instead of `assert` so the check survives `python -O`.
        if len(setA) < 4 or len(setB) < 4:
            raise AssertionError(
                "Problem with input RD point lists. setA is size {}, "
                "setB is size {}".format(len(setA), len(setB)))

        if not cls.isCurveMonotonic(setA):
            raise AssertionError(cls.REJECTED_BD_RATE_NON_MONOTONIC)
        if not cls.isCurveMonotonic(setB):
            raise AssertionError(cls.REJECTED_BD_RATE_NON_MONOTONIC)
        if not cls.ratesLookOkay(setA):
            raise AssertionError(cls.REJECTED_BD_RATE_MESSED_UP_RATE)
        if not cls.ratesLookOkay(setB):
            raise AssertionError(cls.REJECTED_BD_RATE_MESSED_UP_RATE)

        # Points are sorted by rate and quality strictly increases with rate,
        # so the first/last points hold each curve's min/max quality.
        minMainPSNR = setA[0][1]
        maxMainPSNR = setA[-1][1]
        minHighPSNR = setB[0][1]
        maxHighPSNR = setB[-1][1]
        minPSNR = max(minMainPSNR, minHighPSNR)
        maxPSNR = min(maxMainPSNR, maxHighPSNR)

        # No overlap: signal it with a sentinel return value rather than raising.
        if minPSNR >= maxPSNR:
            return cls.REJECTED_BD_RATE_NO_OVERLAP

        vA = cls.bdrint(setA, minPSNR, maxPSNR)
        vB = cls.bdrint(setB, minPSNR, maxPSNR)
        avg = (vB - vA) / (maxPSNR - minPSNR)
        # The base-10 exponent suggests log_rate is log10(rate) -- confirm in
        # InterpolationUtils.computeParamsForSegments.
        return np.power(10, avg) - 1

    @staticmethod
    def isCurveMonotonic(set_):
        """
        Return True iff rate (element 0) and quality (element 1) both strictly
        increase from each point to the next. Empty and single-point
        sequences count as monotonic.

        :param set_: indexable sequence of (rate, quality) points.
        """
        for prev, curr in zip(set_, set_[1:]):
            if prev[0] >= curr[0] or prev[1] >= curr[1]:
                return False
        return True

    @staticmethod
    def ratesLookOkay(set_):
        """
        Return True iff every point's rate (element 0) is strictly positive.

        Rates are presumably log-transformed downstream (see `log_rate` in
        bdrint and the base-10 exponent in CalcBDRate). A zero, negative or
        NaN rate would therefore yield -inf/NaN rather than a usable BD-rate.
        `x > 0` is False for NaN, so NaN rates are rejected too.
        """
        return all(point[0] > 0 for point in set_)

    # BD-rate calculation for arbitrary number (N) points
    # cf. https://www.mathworks.com/moler/interp.pdf, sections 3.3 - 3.4
    @staticmethod
    def bdrint(rdPointsList, minPSNR, maxPSNR):
        """
        Integrate the piecewise-cubic interpolant of log-rate over quality,
        restricted to the quality interval [minPSNR, maxPSNR].

        :param rdPointsList: (rate, quality) points. Expected sorted with
            strictly increasing quality, as CalcBDRate guarantees.
        :param minPSNR: lower integration bound, in the same units as the
            quality values.
        :param maxPSNR: upper integration bound.
        :return: float. This is the sum, over consecutive point pairs, of the
            definite integral of that segment's cubic over the part of the
            segment lying inside [minPSNR, maxPSNR].
        """
        N = len(rdPointsList)
        # Output lists, filled in place by computeParamsForSegments.
        log_rate = []
        log_dist = []
        H = []
        delta = []
        d = []
        c = []
        b = []
        # The trailing True presumably selects BD-rate mode -- confirm in
        # InterpolationUtils.
        InterpolationUtils.computeParamsForSegments(rdPointsList, log_rate, log_dist, H, delta, d, c, b, True)

        # The cubic for segment i is
        #   log_rate(i) + s*d(i) + s*s*c(i) + s*s*s*b(i),  with s = x - log_dist(i).
        # Its primitive is
        #   s*log_rate(i) + s*s*d(i)/2 + s*s*s*c(i)/3 + s*s*s*s*b(i)/4.
        # log_dist is compared directly against minPSNR/maxPSNR below, so it is
        # in the same units as the input quality values.
        result = 0.0
        for i in range(N - 1):
            # Clip the segment [log_dist[i], log_dist[i+1]] to the bounds and
            # express both ends as offsets from the segment's left knot.
            s0 = min(max(log_dist[i], minPSNR), maxPSNR) - log_dist[i]
            s1 = min(max(log_dist[i + 1], minPSNR), maxPSNR) - log_dist[i]

            # An empty (fully clipped) segment contributes nothing.
            if s1 > s0:
                result += (s1 - s0) * log_rate[i]
                result += (s1 * s1 - s0 * s0) * d[i] / 2.0
                result += (s1 * s1 * s1 - s0 * s0 * s0) * c[i] / 3.0
                result += (s1 * s1 * s1 * s1 - s0 * s0 * s0 * s0) * b[i] / 4.0
        return result