-
Notifications
You must be signed in to change notification settings - Fork 2
/
data.py
87 lines (68 loc) · 2.66 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from pathlib import Path
from typing import Any
import torch
import torchvision
from torchvision import transforms
from scaling.core import BaseDataset, BaseDatasetBatch, BaseDatasetItem, broadcast_data
from scaling.core.topology import Topology
class MNISTDatasetItem(BaseDatasetItem):
    """A single MNIST sample with input and target stored as float16 tensors."""

    def __init__(self, input_: Any, target: Any) -> None:
        """Convert ``input_`` and ``target`` to float16 tensors.

        Args:
            input_: Sample input; a ``torch.Tensor`` or anything accepted by
                ``torch.tensor``.
            target: Sample target; a ``torch.Tensor`` or anything accepted by
                ``torch.tensor`` (e.g. a Python int).
        """
        self.input = self._as_float16(input_)
        self.target = self._as_float16(target)

    @staticmethod
    def _as_float16(value: Any) -> torch.Tensor:
        """Return a detached float16 copy of ``value`` as a tensor."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning
            # on every call; detach + copying cast yields the same result
            # (a detached copy) without the warning.
            return value.detach().to(dtype=torch.float16, copy=True)
        return torch.tensor(value, dtype=torch.float16)
class MNISTDatasetBatch(BaseDatasetBatch):
    """Batch container with an optional ``inputs`` and an optional ``targets`` tensor."""

    def __init__(
        self,
        inputs: torch.Tensor | None = None,
        targets: torch.Tensor | None = None,
    ):
        """Store the given tensors; either one may be ``None``."""
        self.inputs, self.targets = inputs, targets

    def only_inputs(self) -> "MNISTDatasetBatch":
        """Return a new batch that keeps ``inputs`` and has ``targets`` set to ``None``."""
        inputs_only = MNISTDatasetBatch(inputs=self.inputs, targets=None)
        return inputs_only

    def only_targets(self) -> "MNISTDatasetBatch":
        """Return a new batch that keeps ``targets`` and has ``inputs`` set to ``None``."""
        targets_only = MNISTDatasetBatch(inputs=None, targets=self.targets)
        return targets_only
class MNISTDataset(BaseDataset[MNISTDatasetItem, MNISTDatasetBatch, MNISTDatasetBatch]):
    """MNIST dataset wrapper around ``torchvision.datasets.MNIST``."""

    def __init__(self, root: Path = Path("./.data"), train: bool = True):
        """Create the underlying torchvision MNIST dataset.

        Args:
            root: Directory passed to torchvision as the dataset root.
            train: Passed through to torchvision to select the train or
                test split.
        """
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        # download=True is always passed, so construction may hit the network
        # when the data is not yet available under ``root``.
        self.dataset = torchvision.datasets.MNIST(
            root=root,
            train=train,
            transform=transform,
            download=True,
        )

    def __len__(self) -> int:
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, index: int) -> MNISTDatasetItem:
        """Return the sample at ``index`` wrapped in an ``MNISTDatasetItem``."""
        # Index the underlying dataset once: every access loads the sample and
        # re-runs the transform pipeline, so indexing twice doubles the work.
        image, label = self.dataset[index]
        return MNISTDatasetItem(input_=image, target=label)

    def ident(self) -> str:
        """Return the identifier string of this dataset."""
        return "MNIST"

    def set_seed(self, seed: int, shuffle: bool = True) -> None:
        """Do nothing; ``seed`` and ``shuffle`` are ignored by this dataset."""
        return

    def collate(self, batch: list[MNISTDatasetItem]) -> MNISTDatasetBatch:
        """Stack the items' inputs and targets along a new leading dimension.

        Args:
            batch: Items to combine; ``torch.stack`` requires a non-empty list
                of equally shaped tensors.

        Returns:
            A batch holding the stacked inputs and stacked targets.
        """
        inputs = torch.stack([batch_item.input for batch_item in batch])
        targets = torch.stack([batch_item.target for batch_item in batch])
        return MNISTDatasetBatch(inputs=inputs, targets=targets)

    @staticmethod
    def sync_batch_to_model_parallel(
        topology: Topology,
        batch: MNISTDatasetBatch | None,
    ) -> MNISTDatasetBatch:
        """Build a batch from tensors passed through ``broadcast_data``.

        Model-parallel rank 0 must supply ``batch``; every other rank must
        pass ``None``.

        Args:
            topology: Topology whose ``model_parallel_rank`` selects the
                providing rank; also forwarded to ``broadcast_data``.
            batch: The batch on model-parallel rank 0, ``None`` elsewhere.

        Returns:
            A batch built from the two tensors returned by ``broadcast_data``.
        """
        if topology.model_parallel_rank == 0:
            assert batch is not None
            tensors: list[torch.Tensor | None] = [batch.inputs, batch.targets]
        else:
            assert batch is None
            tensors = [None, None]

        # NOTE(review): presumably broadcast_data distributes rank 0's tensors
        # to the other model-parallel ranks as float16 - confirm against
        # scaling.core.
        broadcast_tensors = broadcast_data(tensors=tensors, dtype=torch.float16, topology=topology)
        return MNISTDatasetBatch(
            inputs=broadcast_tensors[0],
            targets=broadcast_tensors[1],
        )