-
Notifications
You must be signed in to change notification settings - Fork 0
/
perspective_ranker_test.py
143 lines (110 loc) · 3.76 KB
/
perspective_ranker_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import pytest
from fastapi.encoders import jsonable_encoder
from fastapi.testclient import TestClient
from aioresponses import aioresponses
import perspective_ranker
from ranking_challenge.fake import fake_request
@pytest.fixture
def app():
    """Provide the application object exposed by ``perspective_ranker``."""
    # Kept as a generator fixture (yield, no teardown) like the original.
    yield perspective_ranker.app
@pytest.fixture
def client(app):
    """Provide a ``TestClient`` wrapped around the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
def api_response(attributes, value=0.5):
    """Build a fake scoring payload for the mocked HTTP endpoint.

    Args:
        attributes: Iterable of attribute names to include in the payload.
        value: Summary score assigned to every attribute. Defaults to 0.5,
            the value this helper previously hard-coded.

    Returns:
        A dict of the form
        ``{"attributeScores": {attr: {"summaryScore": {"value": value}}}}``
        with one entry per name in ``attributes`` (empty when no names
        are given).
    """
    # A fresh nested dict is created per attribute, so entries never alias
    # each other if a caller mutates one of them afterwards.
    return {
        "attributeScores": {
            attr: {"summaryScore": {"value": value}} for attr in attributes
        }
    }
def test_rank(client):
    """POST a fake request to ``/rank`` with the scoring endpoint mocked.

    Checks that three outbound calls were recorded, that the endpoint
    answers 200, and that three ranked ids come back.
    """
    request_body = fake_request(n_posts=1, n_comments=2)
    request_body.session.cohort = "perspective_baseline"
    mock_payload = api_response(perspective_ranker.perspective_baseline)

    with aioresponses() as mocked:
        mocked.post(
            perspective_ranker.PERSPECTIVE_URL,
            payload=mock_payload,
            repeat=True,
        )

        response = client.post("/rank", json=jsonable_encoder(request_body))

        # The mock exposes no call_count, so count the calls recorded under
        # the first (method, url) key of its ``requests`` mapping instead.
        recorded_calls = list(mocked.requests.values())
        call_count = len(recorded_calls[0])
        # 1 post + 2 comments -> 3 calls; successful requests are not retried.
        assert call_count == 3

        assert response.status_code == 200
        body = response.json()
        assert len(body["ranked_ids"]) == 3
def test_rank_no_score(client):
    """``/rank`` must still answer 200 with three ids when the mocked
    scoring endpoint returns an empty JSON object."""
    request_body = fake_request(n_posts=1, n_comments=2)
    request_body.session.cohort = "perspective_baseline"

    with aioresponses() as mocked:
        # Empty payload: the response carries no attribute scores at all.
        mocked.post(
            perspective_ranker.PERSPECTIVE_URL,
            payload={},
            repeat=True,
        )

        response = client.post("/rank", json=jsonable_encoder(request_body))

        assert response.status_code == 200
        body = response.json()
        assert len(body["ranked_ids"]) == 3
def test_arm_selection():
    """A request in the ``perspective_baseline`` cohort must select the
    ``perspective_baseline`` arm."""
    request_body = fake_request(n_posts=1, n_comments=2)
    request_body.session.cohort = "perspective_baseline"

    ranker = perspective_ranker.PerspectiveRanker()
    selected_arm = ranker.arm_selection(request_body)

    assert selected_arm == perspective_ranker.perspective_baseline
@pytest.mark.asyncio
async def test_score():
    """Score one statement against the mocked endpoint and check the
    attribute scores, statement text and statement id on the result."""
    ranker = perspective_ranker.PerspectiveRanker()
    attributes = ["TOXICITY"]

    with aioresponses() as mocked:
        mocked.post(
            perspective_ranker.PERSPECTIVE_URL,
            payload=api_response(["TOXICITY"]),
            repeat=True,
        )

        scored = await ranker.score(
            attributes, "Test statement", "test_statement_id"
        )

        # 0.5 is the summary score the mocked payload reports.
        assert scored.attr_scores == [("TOXICITY", 0.5)]
        assert scored.statement == "Test statement"
        assert scored.statement_id == "test_statement_id"
def test_arm_sort():
    """Sort four hand-built scored statements with the
    ``perspective_toxicity`` arm and pin the resulting id order."""
    ranker = perspective_ranker.PerspectiveRanker()

    # Positional ScoredStatement arguments, in order: statement text,
    # (attribute, score) pairs, statement id, a boolean flag and a float.
    # NOTE(review): the flag presumably marks whether the statement could be
    # scored (it is False only for the "unscorable" id), and the meaning of
    # the trailing 0.1 is not visible here; confirm against ScoredStatement.
    rows = [
        (
            "Test statement 2",
            [("TOXICITY", 0.6), ("REASONING_EXPERIMENTAL", 0.2)],
            "test_statement_id_2",
            True,
            0.1,
        ),
        (
            "Test statement",
            [("TOXICITY", 0.1), ("REASONING_EXPERIMENTAL", 0.1)],
            "test_statement_id_1",
            True,
            0.1,
        ),
        (
            "Test statement",
            [("TOXICITY", 0), ("REASONING_EXPERIMENTAL", 0)],
            "test_statement_id_unscorable",
            False,
            0.1,
        ),
        (
            "Test statement 3",
            [("TOXICITY", 0.9), ("REASONING_EXPERIMENTAL", 0.3)],
            "test_statement_id_3",
            True,
            0.1,
        ),
    ]
    scored_statements = [ranker.ScoredStatement(*row) for row in rows]

    outcome = ranker.arm_sort(
        perspective_ranker.perspective_toxicity, scored_statements
    )

    expected_order = [
        "test_statement_id_1",
        "test_statement_id_unscorable",
        "test_statement_id_2",
        "test_statement_id_3",
    ]
    assert outcome["ranked_ids"] == expected_order