dataset.py
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset


class MNISTDataset(Dataset):
    """Binary MNIST subset with an equal number of 0 and 1 digits."""

    @staticmethod
    def transform_data(
        data: torchvision.datasets.MNIST,
        width: int,
        height: int,
        min_len: int,
        scale_index: float
    ) -> tuple[torch.Tensor, ...]:
        # Note: Resize interprets the pair as (h, w); the same (width, height)
        # order is used for the buffers below, so the shapes stay consistent.
        # scale_index is only used by the optional scaling transform, which is
        # currently commented out.
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((width, height)),
                # transforms.Lambda(lambda image: scale_index * image)
            ]
        )
        # Pre-allocate buffers for the images and their labels.
        x_data = torch.empty(2 * min_len, width, height)
        y_data = torch.empty(2 * min_len)
        zeros_labels = torch.zeros(min_len)
        ones_labels = torch.ones(min_len)
        # Keep only the images labelled 0 and 1 and apply the transform.
        zeros = list(
            map(
                lambda x: transform(x[0]),
                filter(lambda x: x[1] == 0, data)
            )
        )
        ones = list(
            map(
                lambda x: transform(x[0]),
                filter(lambda x: x[1] == 1, data)
            )
        )
        # Each transformed image has shape (1, width, height), so concatenating
        # min_len of each class along dim 0 fills the (2 * min_len, ...) buffer.
        torch.cat([*zeros[:min_len], *ones[:min_len]], out=x_data)
        torch.cat([zeros_labels, ones_labels], out=y_data)
        return x_data, y_data

    def __init__(
        self,
        load_dir: str,
        width: int,
        height: int,
        min_len: int,
        train: bool
    ):
        self.__length = 2 * min_len
        # pi, computed from acos(0) = pi / 2; kept for the optional scaling transform.
        self.__pi = 2 * torch.acos(torch.zeros(1)).item()
        loaded_data = torchvision.datasets.MNIST(
            load_dir,
            download=True,
            train=train
        )
        self.x_data, self.y_data = self.transform_data(
            loaded_data, width, height, min_len, self.__pi
        )
        # Add the channel dimension expected by convolutional layers.
        self.x_data = self.x_data.reshape(self.__length, 1, width, height)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, ...]:
        """
        returns:
            image tensor of shape (1, width, height) and its scalar label tensor
        """
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.__length
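

# A minimal usage sketch (an assumption, not part of the original module): it
# shows the dataset wrapped in a DataLoader. The "./data" directory, the 28x28
# target size, min_len=500, and batch_size=32 are illustrative values only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MNISTDataset(
        load_dir="./data", width=28, height=28, min_len=500, train=True
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    images, labels = next(iter(loader))
    # Expected shapes: images (32, 1, 28, 28), labels (32,)
    print(images.shape, labels.shape)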