From f90d4303c18bb7fea81e3e96f650b587d63d2773 Mon Sep 17 00:00:00 2001
From: dnddnjs
Date: Wed, 10 Oct 2018 08:50:07 -0400
Subject: [PATCH] train model with gpu

---
 .gitignore                   |   3 +
 classification/cifar_data.py | 139 -----------------------------------
 classification/train.py      |  11 ++-
 3 files changed, 10 insertions(+), 143 deletions(-)
 create mode 100644 .gitignore
 delete mode 100644 classification/cifar_data.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..619e375
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+data/
+__pycache__/
+*.pth
diff --git a/classification/cifar_data.py b/classification/cifar_data.py
deleted file mode 100644
index 835ee94..0000000
--- a/classification/cifar_data.py
+++ /dev/null
@@ -1,139 +0,0 @@
-from PIL import Image
-import os
-import os.path
-import numpy as np
-import sys
-import pickle
-
-import torch.utils.data as data
-from utils import download_url, check_integrity
-
-
-class CIFAR10(data.Dataset):
-    base_folder = 'cifar-10-batches-py'
-    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    filename = "cifar-10-python.tar.gz"
-    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
-    train_list = [
-        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
-        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
-        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
-        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
-        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
-    ]
-
-    test_list = [
-        ['test_batch', '40351d587109b95175f43aff81a1287e'],
-    ]
-    meta = {
-        'filename': 'batches.meta',
-        'key': 'label_names',
-        'md5': '5ff9c542aee3614f3951f8cda6e48888',
-    }
-
-    def __init__(self, root, train=True,
-                 transform=None, target_transform=None,
-                 download=False):
-        self.root = os.path.expanduser(root)
-        self.transform = transform
-        self.target_transform = target_transform
-        self.train = train
-
-        if download:
-            self.download()
-
-        if not self._check_integrity():
-            raise RuntimeError('Dataset not found or corrupted.' +
-                               ' You can use download=True to download it')
-
-        if self.train:
-            downloaded_list = self.train_list
-        else:
-            downloaded_list = self.test_list
-
-        self.data = []
-        self.targets = []
-
-        # now load the picked numpy arrays
-        for file_name, checksum in downloaded_list:
-            file_path = os.path.join(self.root, self.base_folder, file_name)
-            with open(file_path, 'rb') as f:
-                entry = pickle.load(f, encoding='latin1')
-                self.data.append(entry['data'])
-                if 'labels' in entry:
-                    self.targets.extend(entry['labels'])
-                else:
-                    self.targets.extend(entry['fine_labels'])
-
-        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
-        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
-
-        self._load_meta()
-
-    def _load_meta(self):
-        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
-        if not check_integrity(path, self.meta['md5']):
-            raise RuntimeError('Dataset metadata file not found or corrupted.' +
-                               ' You can use download=True to download it')
-        with open(path, 'rb') as infile:
-            data = pickle.load(infile, encoding='latin1')
-            self.classes = data[self.meta['key']]
-        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-        Returns:
-            tuple: (image, target) where target is index of the target class.
-        """
-        img, target = self.data[index], self.targets[index]
-
-        # doing this so that it is consistent with all other datasets
-        # to return a PIL Image
-        img = Image.fromarray(img)
-
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
-
-        return img, target
-
-    def __len__(self):
-        return len(self.data)
-
-    def _check_integrity(self):
-        root = self.root
-        for fentry in (self.train_list + self.test_list):
-            filename, md5 = fentry[0], fentry[1]
-            fpath = os.path.join(root, self.base_folder, filename)
-            if not check_integrity(fpath, md5):
-                return False
-        return True
-
-    def download(self):
-        import tarfile
-
-        if self._check_integrity():
-            print('Files already downloaded and verified')
-            return
-
-        download_url(self.url, self.root, self.filename, self.tgz_md5)
-
-        # extract file
-        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
-            tar.extractall(path=self.root)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        tmp = 'train' if self.train is True else 'test'
-        fmt_str += '    Split: {}\n'.format(tmp)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
diff --git a/classification/train.py b/classification/train.py
index cfb88ad..b7dbeca 100644
--- a/classification/train.py
+++ b/classification/train.py
@@ -1,3 +1,4 @@
+import os
 from resnet import model
 from cifar_data import CIFAR10
 
@@ -5,6 +6,8 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.optim import lr_scheduler
+import torch.backends.cudnn as cudnn
+
 import torchvision
 import torchvision.transforms as transforms
 
@@ -31,8 +34,8 @@
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
 ])
 
-dataset_train = CIFAR10(root='./data', train=True, download=True, transform=transforms_train)
-dataset_test = CIFAR10(root='./data', train=False, download=True, transform=transforms_test)
+dataset_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_train)
+dataset_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_test)
 train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size,
                                            shuffle=True, num_workers=args.num_worker)
 test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=100,
@@ -46,7 +49,7 @@
 net = model.resnet18()
 net = net.to(device)
 if device == 'cuda':
-    net = cuda.nn.DataParallel(net)
+    net = torch.nn.DataParallel(net)
     cudnn.benchmark = True
 
 
@@ -126,4 +129,4 @@ def test(epoch, best_acc):
 best_acc = 0
 for epoch in range(200):
     train(epoch)
-    best_acc = test(epoch, best_acc)
\ No newline at end of file
+    best_acc = test(epoch, best_acc)