-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
133 lines (123 loc) · 2.98 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import numpy as np
from jax import numpy as jnp
import jax
def load_pytree_from_dir(base_dir: str) -> dict:
    """Rebuild a pytree from a directory of .npy files.

    The directory must contain ``tree_def.npy`` and leaf files ``array_0.npy``,
    ``array_1.npy``, ... ``tree_def.npy`` is loaded with ``allow_pickle=True``
    and unwrapped with ``.item()``. The resulting object is used as the tree
    definition for ``jax.tree_util.tree_unflatten``. Leaves are read in index
    order, starting at 0 and stopping at the first missing index. Each leaf is
    converted to a ``jnp`` array.

    NOTE(review): ``allow_pickle=True`` lets the file execute arbitrary code on
    load, so only use this on directories from a trusted source.

    Args:
        base_dir (str): directory holding ``tree_def.npy`` and the leaf files.

    Returns:
        dict: the reconstructed pytree. The actual container type is whatever
        the stored tree definition describes.
    """
    def leaf_file(index: int) -> str:
        # Path of the index-th flattened leaf.
        return os.path.join(base_dir, f'array_{index}.npy')

    structure = np.load(os.path.join(base_dir, 'tree_def.npy'), allow_pickle=True).item()
    leaves = []
    index = 0
    # Leaves are numbered contiguously; the first gap marks the end.
    while os.path.exists(leaf_file(index)):
        leaves.append(jnp.array(np.load(leaf_file(index))))
        index += 1
    return jax.tree_util.tree_unflatten(structure, leaves)
# Token vocabulary: list index == token id.
# Ids 0-5 are special tokens; ids 6-31 are single-letter residue codes
# (these look like protein amino-acid one-letter codes, with the 20 standard
# residues first followed by X, B, Z, J, U, O).
id_to_token = [
    "<unk>", "<pad>", "<mask>", "<cls>", "<eos>", "<bos>",
] + list("ARNDCQEGHILKMFPSTWYVXBZJUO")
def sample_to_string(sample: jax.Array) -> str:
    """Decode a sequence of token ids into a string.

    Decoding stops at the first ``<eos>`` id. Ids at or below the ``<bos>`` id
    (the special tokens ``<unk>``, ``<pad>``, ``<mask>``, ``<cls>``, ``<bos>``)
    are dropped. Every other id is mapped through ``id_to_token``.

    Args:
        sample (jax.Array): 1-D sequence of integer token ids. A NumPy array
            or a list of ints is accepted as well.

    Returns:
        str: the decoded string.

    Raises:
        IndexError: if an id is past the end of ``id_to_token``.
    """
    eos_id = id_to_token.index("<eos>")
    # Ids <= this one are special tokens and are not emitted.
    last_special_id = id_to_token.index("<bos>")
    # Move the data to the host once. Iterating a jax.Array element by element
    # would run a separate JAX op for every comparison below.
    ids = np.asarray(sample).tolist()
    chars = []
    for token_id in ids:
        if token_id == eos_id:
            break
        if token_id > last_special_id:
            chars.append(id_to_token[token_id])
    return "".join(chars)
def string_to_sample(sample: str, length: int) -> jax.Array:
    """Encode a string as a fixed-length array of token ids.

    Each character is mapped to its index in ``id_to_token``. If the string is
    shorter than ``length``, the ``<eos>`` id is written right after it and
    every remaining position holds the ``<pad>`` id. A string of exactly
    ``length`` characters gets no ``<eos>``.

    Args:
        sample (str): string to encode. Every character must be in
            ``id_to_token``.
        length (int): length of the returned array.

    Returns:
        jax.Array: int32 array of shape (length,).

    Raises:
        IndexError: if ``sample`` is longer than ``length``.
        ValueError: if ``sample`` contains a token not in ``id_to_token``, or
            if ``length`` is negative (raised by numpy).
    """
    token_to_id = {token: idx for idx, token in enumerate(id_to_token)}
    # Allocate first so a negative length still fails inside numpy, as before.
    result = np.full(length, token_to_id["<pad>"], dtype=np.int32)
    if len(sample) > length:
        raise IndexError(
            f"sample has {len(sample)} tokens but length is only {length}"
        )
    for position, token in enumerate(sample):
        if token not in token_to_id:
            raise ValueError(f"unknown token {token!r} at position {position}")
        result[position] = token_to_id[token]
    # Terminate with EOS only when there is room for it.
    if len(sample) < length:
        result[len(sample)] = token_to_id["<eos>"]
    return jnp.array(result)
def count_sequential_repeats(string: str, *, min_length: int = 4) -> dict:
    """Collect the lengths of long runs of one repeated character.

    Args:
        string (str): input string.
        min_length (int): shortest run that is reported (keyword-only).
            The default of 4 reports runs of more than three identical
            consecutive characters.

    Returns:
        dict: maps each character to the list of its qualifying run lengths,
        in order of appearance. Characters with no qualifying run are absent,
        and an empty string yields ``{}``.

    Example:
        >>> count_sequential_repeats("AAAABBCCCCCAAAA")
        {'A': [4, 4], 'C': [5]}
    """
    from itertools import groupby

    repeats = {}
    # groupby yields one group per maximal run of equal consecutive characters.
    for char, run in groupby(string):
        run_length = sum(1 for _ in run)
        if run_length >= min_length:
            repeats.setdefault(char, []).append(run_length)
    return repeats
def repetition_score(string: str) -> float:
    """Return the fraction of the string covered by long repeated-character runs.

    The runs reported by ``count_sequential_repeats`` (default minimum run
    length) are summed and divided by ``len(string)``. The runs never overlap,
    so the score lies in [0, 1].

    Args:
        string (str): input string.

    Returns:
        float: repetition score. An empty string scores 0.0 rather than
        dividing zero by zero.
    """
    if not string:
        return 0.0
    counts = count_sequential_repeats(string)
    repeated = sum(sum(run_lengths) for run_lengths in counts.values())
    return repeated / len(string)