-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdlc_bci.py
75 lines (52 loc) · 2.05 KB
/
dlc_bci.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# This is distributed under BSD 3-Clause license
import torch
import numpy
import os
import errno
from six.moves import urllib
def tensor_from_file(root, filename,
                     base_url = 'https://documents.epfl.ch/users/f/fl/fleuret/www/data/bci'):
    """
    Return the numeric content of ``root/filename`` as a torch tensor.

    If the file is not present locally it is first downloaded from
    ``base_url + '/' + filename`` and cached under ``root`` (created if needed).

    Args:
        root (string): Directory where the file is cached.
        filename (string): File name, relative to both ``root`` and ``base_url``.
        base_url (string, optional): URL prefix the file is downloaded from.

    Returns:
        Tensor: ``torch.from_numpy`` of the array parsed by ``numpy.loadtxt``.

    Raises:
        OSError: if ``root`` cannot be created (other than because it already
            exists) or the downloaded data cannot be written.
        Any error raised by ``urlopen`` when the download fails; in that case
        no partial file is left behind.
    """
    import contextlib
    import shutil

    file_path = os.path.join(root, filename)
    if not os.path.exists(file_path):
        # os.makedirs('') raises; an empty root simply means the current directory.
        if root:
            try:
                os.makedirs(root)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
        url = base_url + '/' + filename
        print('Downloading ' + url)
        # Write to a temporary name and rename only once the transfer is complete,
        # so an interrupted download never leaves a truncated file that a later
        # call would mistake for a valid cached copy.
        tmp_path = file_path + '.part'
        try:
            # closing() releases the connection even on Python 2, whose response
            # objects are not context managers.
            with contextlib.closing(urllib.request.urlopen(url)) as data, \
                    open(tmp_path, 'wb') as f:
                shutil.copyfileobj(data, f)
            os.rename(tmp_path, file_path)
        except BaseException:
            # Best-effort cleanup; never mask the original error.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass
            raise
    return torch.from_numpy(numpy.loadtxt(file_path))
def load(root, train = True, download = True, one_khz = False):
    """
    Load the ``sp1s_aa`` BCI recordings stored in (or downloaded to) ``root``.

    Args:
        root (string): Root directory of dataset.
        train (bool, optional): If True, creates dataset from training data;
            otherwise from the test data plus its separate label file.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. If False, the files must already be present in root.
        one_khz (bool, optional): If True, creates dataset from the 1000Hz data instead
            of the default 100Hz.

    Returns:
        (input, target): ``input`` is a FloatTensor of shape
        (nb_samples, 28, nb_time_steps); ``target`` is a LongTensor of shape
        (nb_samples,). For the training split the label is the first column of
        each row; for the test split labels come from ``labels_data_set_iv.txt``.

    Raises:
        RuntimeError: if ``download`` is False and a required file is missing
            from ``root``.
    """
    nb_electrodes = 28

    # The data files differ only by split name and an optional 1000Hz suffix.
    suffix = '_1000Hz' if one_khz else ''
    if train:
        filenames = ['sp1s_aa_train' + suffix + '.txt']
    else:
        filenames = ['sp1s_aa_test' + suffix + '.txt', 'labels_data_set_iv.txt']

    # Honour download=False: never reach the network, fail clearly instead.
    if not download:
        missing = [name for name in filenames
                   if not os.path.exists(os.path.join(root, name))]
        if missing:
            raise RuntimeError(
                'Dataset file(s) %s not found in %s; use download=True to fetch them.'
                % (', '.join(missing), root))

    if train:
        dataset = tensor_from_file(root, filenames[0])
        # Column 0 holds the class label; the remaining columns hold the signal.
        inputs = dataset.narrow(1, 1, dataset.size(1) - 1)
        target = dataset.narrow(1, 0, 1).clone().view(-1).long()
    else:
        inputs = tensor_from_file(root, filenames[0])
        target = tensor_from_file(root, filenames[1]).view(-1).long()

    # NOTE(review): assumes each row stores the 28 electrode channels one after
    # another (electrode-major); confirm against the dataset description.
    inputs = inputs.float().view(inputs.size(0), nb_electrodes, -1)
    return inputs, target