-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmy_datasets.py
43 lines (36 loc) · 1.23 KB
/
my_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import one_hot
class mydatasets(Dataset):
    """Image dataset whose labels are encoded in the image file names.

    Every entry of ``root_dir`` is treated as an image file. The part of the
    file name before the first ``.`` is handed to ``one_hot.text2vec`` and the
    result, flattened to one dimension, is returned as the label.
    """

    def __init__(self, root_dir):
        """Collect the image paths under ``root_dir`` and build the transform.

        Args:
            root_dir: Directory whose entries are all loaded as images.
        """
        super().__init__()
        self.list_image_path = [
            os.path.join(root_dir, image_name)
            for image_name in os.listdir(root_dir)
        ]
        # Resize to (50, 100), convert to a tensor, then reduce to grayscale.
        self.transforms = transforms.Compose([
            transforms.Resize((50, 100)),
            transforms.ToTensor(),
            transforms.Grayscale(),
        ])

    @staticmethod
    def _label_text(image_path):
        """Return the label text stored in the file name of ``image_path``.

        ``os.path.basename`` is used instead of splitting on a backslash so
        the directory part is stripped with the running platform's separator
        rules; splitting on ``"\\"`` only worked on Windows and produced an
        empty label for POSIX paths such as ``./dataset/train/ab12.png``.
        """
        image_name = os.path.basename(image_path)
        # Everything before the first dot of the file name is the label.
        return image_name.split(".")[0]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.list_image_path``.

        Returns:
            A tuple ``(image_tensor, label)`` where ``image_tensor`` is the
            transformed image and ``label`` is the ``one_hot.text2vec``
            encoding of the file-name label, flattened to one dimension.
        """
        image_path = self.list_image_path[index]
        # The transform fully reads the image, so the file can be released
        # as soon as the tensor has been produced.
        with Image.open(image_path) as img_:
            img_tensor = self.transforms(img_)
        img_label = one_hot.text2vec(self._label_text(image_path))
        img_label = img_label.view(1, -1)[0]
        return img_tensor, img_label

    def __len__(self):
        """Return the number of image paths found in ``root_dir``."""
        return len(self.list_image_path)
if __name__ == '__main__':
    # Smoke test: load one sample and log its image to TensorBoard.
    d = mydatasets("./dataset/train")
    img, label = d[2]
    # The context manager closes the writer even if logging raises.
    with SummaryWriter("logs") as writer:
        writer.add_image("img", img, 1)
        print(img.shape)
    # View the result with: tensorboard --logdir=logs