-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
22 lines (18 loc) · 818 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import json
from typing import Iterable
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
def write_results(results: Iterable, filename: str):
    """Serialize every item of *results* as one JSON document per line.

    The file at *filename* is created (or truncated) and written as UTF-8.
    Each element becomes ``json.dumps(element)`` followed by a newline, so
    the output is one JSON value per line.

    Args:
        results: Iterable of JSON-serializable objects; consumed exactly once.
        filename: Destination path; any existing content is overwritten.
    """
    with open(filename, "w", encoding="utf-8") as sink:
        # Build the generator inside the ``with`` so the file is opened
        # before ``results`` is iterated, matching the original ordering.
        sink.writelines(json.dumps(record) + "\n" for record in results)
def performance_metrics(y_true, y_pred, *, average="weighted", zero_division=0):
    """Calculate classification metrics for predictions against true labels.

    Args:
        y_true: Ground-truth labels (any array-like scikit-learn accepts).
        y_pred: Predicted labels, aligned element-wise with ``y_true``.
        average: Averaging strategy forwarded to scikit-learn's precision,
            recall and F1 scorers (e.g. ``"weighted"``, ``"macro"``,
            ``"micro"``, ``"binary"``, ``"samples"`` or ``None``). Defaults
            to ``"weighted"``, the value this function always used before.
            With ``None`` the three scores are per-class arrays, not scalars.
        zero_division: Value reported when a score is undefined (for example,
            precision for a class that is never predicted). Defaults to ``0``;
            any non-``"warn"`` value also avoids scikit-learn's
            undefined-metric warning.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``. Accuracy is unaffected by ``average``/``zero_division``.
    """
    # One shared set of options keeps precision, recall and F1 consistent
    # with each other.
    averaged = {"average": average, "zero_division": zero_division}
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, **averaged),
        "recall": recall_score(y_true, y_pred, **averaged),
        "f1": f1_score(y_true, y_pred, **averaged),
    }