-
Notifications
You must be signed in to change notification settings - Fork 37
/
util.py
201 lines (156 loc) · 5.35 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import re
import json
import hashlib
import shutil
import traceback
import requests
from io import BytesIO
from dataclasses import dataclass
from typing import Optional, List, Any, Union
import torch
def round1(x: float) -> float:
    """Return `x` rounded to a single digit after the decimal point."""
    ndigits = 1
    return round(x, ndigits)
def mean(x: List[float]) -> float:
    """Return the arithmetic mean of `x`.

    Raises:
        ZeroDivisionError: if `x` is empty.
    """
    total = sum(x)
    n = len(x)
    return total / n
def count(list, x):
    """Count how many elements of `list` compare equal to `x`.

    `list` may be any iterable. The parameter name shadows the builtin
    `list`; it is kept so existing keyword callers keep working.
    """
    occurrences = 0
    for item in list:
        if item == x:
            occurrences += 1
    return occurrences
def get_device(index: int = 0) -> torch.device:
    """Return device ``cuda:<index>`` when CUDA is available, otherwise the CPU device."""
    device_name = f"cuda:{index}" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
def ensure_directory_exists(path: str):
    """Create directory `path`, including missing parents, if nothing exists there yet.

    If `path` already exists, whether as a directory or a file, it is left
    untouched.
    """
    if not os.path.exists(path):
        # makedirs also creates intermediate directories (os.mkdir would raise
        # FileNotFoundError), and exist_ok tolerates `path` being created by
        # someone else between the check above and this call.
        os.makedirs(path, exist_ok=True)
def download_file(url: str, filename: str):
    """Download `url` and save the contents to `filename`. Skip if `filename` already exists.

    The response body is streamed into a temporary sibling file
    (``filename + ".part"``) and only moved to `filename` once it is complete.
    Because an existing `filename` short-circuits every later call, this keeps
    an interrupted download or an HTTP error page from being kept forever as
    if it were the real file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(filename):
        return
    print(f"Downloading {url} to {filename}")
    partial = filename + ".part"
    with requests.get(url, stream=True) as response:
        # Fail loudly rather than caching an error page.
        response.raise_for_status()
        try:
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 16):
                    f.write(chunk)
            # Move into place only after the write finished.
            os.replace(partial, filename)
        except BaseException:
            # Do not leave a half-written temp file behind.
            try:
                os.remove(partial)
            except FileNotFoundError:
                pass
            raise
def cached(url: str) -> str:
    """Download `url` if needed and return the location of the cached file.

    Cached files live in the relative directory ``var``, which is created if
    missing. Each file is named ``<md5 of url>-<sanitized url>``.
    """
    # Keep the URL recognizable in the filename, but replace runs of
    # filesystem-unfriendly characters with "_".
    name = re.sub(r"[^\w_-]+", "_", url)
    # The hash prefix keeps distinct URLs that sanitize to the same name apart.
    url_hash = hashlib.md5(url.encode('utf-8')).hexdigest()
    cache_dir = "var"
    # download_file opens the target path directly without creating parent
    # directories, so make sure the cache directory exists first.
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, url_hash + "-" + name)
    download_file(url, path)
    return path
def get_stack(pop_stack: bool = False):
    """
    Describe the current call stack as a list of plain dicts.

    Only frames after the innermost ``<module>`` frame are kept (if there is
    one). The frames of `get_stack` itself and of its direct caller (e.g.
    point/figure/etc.) are dropped. If `pop_stack` is true, one more frame,
    the caller's caller, is dropped as well.

    Returns:
        A list of ``{"name": ..., "filename": ..., "lineno": ...}`` dicts,
        outermost frame first, where ``filename`` is a basename.
    """
    # Must be called directly here: the fixed trimming below assumes the last
    # entry is this function and the one before it is our caller.
    frames = traceback.extract_stack()

    module_positions = [j for j, fr in enumerate(frames) if fr.name == "<module>"]
    if module_positions:
        # Drop everything up to and including the last <module> frame.
        frames = frames[module_positions[-1] + 1:]

    frames_to_drop = 3 if pop_stack else 2
    frames = frames[:-frames_to_drop]

    return [
        dict(name=fr.name, filename=os.path.basename(fr.filename), lineno=fr.lineno)
        for fr in frames
    ]
def note(message: str, style: Optional[dict] = None, verbatim: bool = False, pop_stack: bool = False):
    """Make a note (bullet point) with `message`.

    The message is echoed to stdout. One ``addText`` entry is appended to the
    content file per rendered line.

    Args:
        message: Text of the note.
        style: Optional style properties (CSS-style keys such as
            "font-family"). The caller's dict is not modified.
        verbatim: If true, each line of `message` becomes its own entry with
            monospace, whitespace-preserving defaults. Keys in `style` override
            those defaults.
        pop_stack: Forwarded to `get_stack` to drop one extra caller frame.
    """
    print("note:", message)
    style = style or {}
    if verbatim:
        lines = message.split("\n")
        style = {
            "font-family": "monospace",
            "white-space": "pre",
            **style
        }
    else:
        lines = [message]
    # All lines share one call site, so capture the stack once instead of per
    # line. It must be called directly from here for get_stack's frame
    # trimming to be correct.
    stack = get_stack(pop_stack=pop_stack)
    for line in lines:
        add_content("addText", [stack, line, style])
def see(obj: Any, pop_stack: bool = False):
    """Reference `obj` in the rendered content as gray text.

    `obj` is echoed to stdout with a "see:" prefix. The content entry uses
    `obj` itself if it is a string, otherwise ``str(obj)``.
    """
    print("see:", obj)
    text = obj if isinstance(obj, str) else str(obj)
    stack = get_stack(pop_stack=pop_stack)
    add_content("addText", [stack, text, {"color": "gray"}])
def image(path: str, style: Optional[dict] = None, width: float = 1.0, pop_stack: bool = False):
    """Show the image at `path`.

    Args:
        path: Image location, passed to the content file unchanged.
        style: Optional extra style properties. The caller's dict is copied,
            never modified, so it can safely be reused across calls.
        width: Written as a percentage string under "width" (e.g. 0.5 ->
            "50.0%"), overriding any "width" in `style`.
        pop_stack: Forwarded to `get_stack` to drop one extra caller frame.
    """
    print("image:", path)
    # Copy instead of assigning into the caller's dict.
    style = {**(style or {}), "width": str(width * 100) + "%"}
    stack = get_stack(pop_stack=pop_stack)
    add_content("addImage", [stack, path, style])
# Where the contents of the lecture are written to be displayed via `view.html`.
# Set by `init_content`; `add_content` refuses to run while it is unset (None).
content_path: Optional[str] = None
def init_content(path: str):
    """Make `path` the lecture content file and truncate it to zero length.

    Later `add_content` calls append to this file.
    """
    global content_path
    content_path = path
    # Opening in "w" mode creates or truncates the file; nothing is written.
    with open(path, "w"):
        pass
def add_content(function_name, args: List[Any]):
    """Append the line ``function_name(arg1, arg2, ...)`` to the content file.

    Each argument is serialized with `json.dumps`, so the line reads as a call
    with JSON-literal arguments. It is presumably evaluated by `view.html`;
    verify against the viewer.

    Raises:
        RuntimeError: if `init_content` has not set a content path yet.
    """
    if not content_path:
        # Explicit check rather than `assert`, which `python -O` strips
        # (leaving an obscure failure from open(None)).
        raise RuntimeError("add_content() called before init_content()")
    line = function_name + "(" + ", ".join(json.dumps(arg) for arg in args) + ")"
    # Append to the file
    with open(content_path, "a") as f:
        print(line, file=f)
############################################################
@dataclass(frozen=True)
class Spec:
    """Immutable metadata shared by every spec kind below.

    All fields are optional and default to None.
    """
    name: Optional[str] = None
    author: Optional[str] = None
    organization: Optional[str] = None
    date: Optional[str] = None
    url: Optional[str] = None
    # Either a single string or a list of strings.
    description: Optional[Union[str, List[str]]] = None
    references: Optional[List[Any]] = None


@dataclass(frozen=True)
class MethodSpec(Spec):
    """Spec for a method; adds no fields beyond `Spec`."""


@dataclass(frozen=True)
class DataSpec(Spec):
    """Spec for a dataset: `Spec` metadata plus size statistics."""
    num_tokens: Optional[int] = None
    vocabulary_size: Optional[int] = None


@dataclass(frozen=True)
class ArchitectureSpec(Spec):
    """Spec for a model architecture: `Spec` metadata plus shape parameters.

    `description` (str or list of str) and `references` are inherited from
    `Spec` unchanged.
    """
    num_parameters: Optional[int] = None
    num_layers: Optional[int] = None
    dim_model: Optional[int] = None
    num_heads: Optional[int] = None
    dim_head: Optional[int] = None


@dataclass(frozen=True)
class TrainingSpec(Spec):
    """Spec for a training run: `Spec` metadata plus training hyperparameters.

    `references` is inherited from `Spec` unchanged.
    """
    context_length: Optional[int] = None
    batch_size_tokens: Optional[int] = None
    learning_rate: Optional[float] = None
    weight_decay: Optional[float] = None
    optimizer: Optional[str] = None
    hardware: Optional[str] = None
    num_epochs: Optional[int] = None
    num_flops: Optional[int] = None


@dataclass(frozen=True)
class ModelSpec(Spec):
    """Spec for a model, grouping its data, architecture and training specs."""
    data: Optional[DataSpec] = None
    architecture: Optional[ArchitectureSpec] = None
    training: Optional[TrainingSpec] = None