
Commit

add: dataset loaders
MarcoGarofalo94 committed Jun 20, 2024
1 parent a38f5fb commit 97e00f7
Showing 2 changed files with 185 additions and 0 deletions.
85 changes: 85 additions & 0 deletions datasets/loaders/cifar10_dataset.py
@@ -0,0 +1,85 @@
import numpy as np
import os

def load_dataset(data_dir):
    train_dataset = load_train_dataset(data_dir)
    test_dataset = load_test_dataset(data_dir)
    return {"train": train_dataset, "test": test_dataset}

def split_dataset(dataset, distribution, dataset_length, chunk_size):
    if distribution == "iid":
        batches = split_chunks_iid(dataset, dataset_length, chunk_size)
    elif distribution == "non-iid":
        batches = split_chunks_non_iid(dataset, dataset_length, chunk_size)
    else:
        raise ValueError(f"unknown distribution: {distribution}")
    return batches

#####################################################################

def load_train_dataset(data_dir):
    # The CIFAR-10 binary release ships five training files: data_batch_1.bin .. data_batch_5.bin.
    data_batches = []
    for i in range(1, 6):
        batch_file = os.path.join(data_dir, f'data_batch_{i}.bin')
        with open(batch_file, 'rb') as f:
            data_batch = np.fromfile(f, dtype=np.uint8)
        data_batches.append(data_batch)
    combined_data = combine_batches(data_batches)
    return combined_data

def load_test_dataset(data_dir):
    data_batches = []
    batch_file = os.path.join(data_dir, 'test_batch.bin')
    with open(batch_file, 'rb') as f:
        data_batch = np.fromfile(f, dtype=np.uint8)
    data_batches.append(data_batch)
    combined_data = combine_batches(data_batches)
    return combined_data

def combine_batches(data_batches):
    # Each CIFAR-10 binary record is 3073 bytes: 1 label byte followed by 3072 pixel bytes (32x32 RGB).
    combined_data = {'data': np.concatenate([batch.reshape(-1, 3073) for batch in data_batches], axis=0)}
    combined_data['labels'] = combined_data['data'][:, 0]
    combined_data['images'] = combined_data['data'][:, 1:]
    del combined_data['data']
    return combined_data

def split_chunks_iid(data, dataset_length, batch_size):
    # Cut the dataset into contiguous, equally sized chunks; the last chunk
    # absorbs any remainder so no samples are dropped.
    num_batches = dataset_length // batch_size
    samples_per_batch = dataset_length // num_batches
    batches = []
    for i in range(num_batches):
        start_idx = i * samples_per_batch
        end_idx = (i + 1) * samples_per_batch if i < num_batches - 1 else dataset_length
        batch_data = {
            'images': data['images'][start_idx:end_idx],
            'labels': data['labels'][start_idx:end_idx]
        }
        batches.append(batch_data)
    return batches

def split_chunks_non_iid(data, dataset_length, batch_size):
    # Slice out one record per sample, sort the records by label, then cut the
    # sorted sequence into contiguous chunks so each chunk covers only a few classes.
    num_batches = dataset_length // batch_size
    num_rows = len(data['images'])
    rows_per_sample = num_rows // dataset_length
    samples = []
    for i in range(dataset_length):
        start_idx = i * rows_per_sample
        end_idx = (i + 1) * rows_per_sample if i < dataset_length - 1 else num_rows
        sample = {
            'images': data['images'][start_idx:end_idx],
            'labels': data['labels'][start_idx:end_idx]
        }
        samples.append(sample)

    # Sort the per-sample records by label value.
    samples_sorted = sorted(samples, key=lambda x: int(x['labels'][0]))

    batches = []
    length_batch = dataset_length // num_batches
    for i in range(num_batches):
        start_idx = i * length_batch
        end_idx = (i + 1) * length_batch if i < num_batches - 1 else dataset_length
        batch = {
            'images': np.concatenate([b['images'] for b in samples_sorted[start_idx:end_idx]], axis=0),
            'labels': np.concatenate([b['labels'] for b in samples_sorted[start_idx:end_idx]], axis=0)
        }
        batches.append(batch)

    return batches
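
For reference, a minimal usage sketch of this loader, assuming the CIFAR-10 binary batches have been unpacked into ./data/cifar-10-batches-bin, that the module is importable from the repository root, and that a non-IID split into chunks of 500 samples is wanted (the path and sizes are illustrative, not part of this commit):

from datasets.loaders import cifar10_dataset

dataset = cifar10_dataset.load_dataset("./data/cifar-10-batches-bin")
train = dataset["train"]  # {'images': (50000, 3072) uint8, 'labels': (50000,) uint8}
chunks = cifar10_dataset.split_dataset(train, "non-iid", len(train["labels"]), 500)
print(len(chunks), chunks[0]["images"].shape)  # 100 chunks, each (500, 3072)
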
100 changes: 100 additions & 0 deletions datasets/loaders/mnist_dataset.py
@@ -0,0 +1,100 @@
import numpy as np
import os

def load_dataset(data_dir):
    train_dataset = load_train_dataset(data_dir)
    test_dataset = load_test_dataset(data_dir)
    return {"train": train_dataset, "test": test_dataset}

def split_dataset(dataset, distribution, dataset_length, chunk_size):
    # Expects the train-set layout produced by load_train_dataset: {"data": [images, labels]}.
    images_data = dataset["data"][0]
    labels_data = dataset["data"][1]

    if distribution == "iid":
        chunks_images, chunks_labels = split_chunks_iid(images_data, labels_data, dataset_length, chunk_size)
    elif distribution == "non-iid":
        chunks_images, chunks_labels = split_chunks_non_iid(images_data, labels_data, dataset_length, chunk_size)
    else:
        raise ValueError(f"unknown distribution: {distribution}")

    batches = []
    for i in range(len(chunks_images)):
        batches.append({"images": chunks_images[i], "labels": chunks_labels[i]})
    return batches


#####################################################################

def load_train_dataset(data_dir):
    images_file = os.path.join(data_dir, 'train-images-idx3-ubyte')
    labels_file = os.path.join(data_dir, 'train-labels-idx1-ubyte')

    with open(images_file, 'rb') as f:
        f.read(16)  # skip the 16-byte IDX image header (magic, count, rows, cols)
        images_data = np.fromfile(f, dtype=np.uint8).reshape(-1, 28*28)
    with open(labels_file, 'rb') as f:
        f.read(8)  # skip the 8-byte IDX label header (magic, count)
        labels_data = np.fromfile(f, dtype=np.uint8)
    return {"data": [images_data, labels_data]}

def load_test_dataset(data_dir):
    images_file = os.path.join(data_dir, 't10k-images-idx3-ubyte')
    labels_file = os.path.join(data_dir, 't10k-labels-idx1-ubyte')

    with open(images_file, 'rb') as f:
        f.read(16)  # skip the 16-byte IDX image header
        images_data = np.fromfile(f, dtype=np.uint8).reshape(-1, 28*28)

    with open(labels_file, 'rb') as f:
        f.read(8)  # skip the 8-byte IDX label header
        labels_data = np.fromfile(f, dtype=np.uint8)
    # Note: unlike load_train_dataset, the test set is returned keyed by "images"/"labels"
    # rather than wrapped in "data", so it does not match the layout split_dataset expects.
    return {"images": images_data, "labels": labels_data}


def split_chunks_iid(images_data, labels_data, dataset_length, samples_per_batch):
    # Cut the arrays into contiguous chunks of samples_per_batch rows; the last
    # chunk absorbs any remainder.
    num_samples = len(images_data)
    num_batches = num_samples // samples_per_batch
    batches_images = []
    batches_labels = []
    for i in range(num_batches):
        start_idx = i * samples_per_batch
        end_idx = (i + 1) * samples_per_batch if i < num_batches - 1 else num_samples
        batch_images = images_data[start_idx:end_idx]
        batch_labels = labels_data[start_idx:end_idx]
        batches_images.append(batch_images)
        batches_labels.append(batch_labels)
    return batches_images, batches_labels

def split_chunks_non_iid(images_data, labels_data, length_dataset, samples_per_batch):
    # Slice out one record per sample, sort the records by label, then group the
    # sorted records into batches so each batch covers only a few digit classes.
    num_rows_images = len(images_data)
    rows_per_sample = num_rows_images // length_dataset
    num_rows_labels = len(labels_data)
    labels_per_sample = num_rows_labels // length_dataset
    sample_images = []
    sample_labels = []

    for i in range(length_dataset):
        start_idx = i * rows_per_sample
        end_idx = (i + 1) * rows_per_sample if i < length_dataset - 1 else num_rows_images
        sample_images.append(images_data[start_idx:end_idx])

        start_idx = i * labels_per_sample
        end_idx = (i + 1) * labels_per_sample if i < length_dataset - 1 else num_rows_labels
        sample_labels.append(labels_data[start_idx:end_idx])

    # sort samples per label
    sorted_indices = sorted(range(len(sample_labels)), key=lambda k: int(sample_labels[k][0]))
    sorted_images = [sample_images[idx] for idx in sorted_indices]
    sorted_labels = [sample_labels[idx] for idx in sorted_indices]

    # split into num_batches groups of samples_per_batch consecutive sorted samples
    num_samples = len(images_data)
    num_batches = num_samples // samples_per_batch
    batches_images = []
    batches_labels = []
    for i in range(num_batches):
        batches_images.append(np.concatenate(sorted_images[i*samples_per_batch:(i+1)*samples_per_batch], axis=0))
        batches_labels.append(np.concatenate(sorted_labels[i*samples_per_batch:(i+1)*samples_per_batch], axis=0))

    return batches_images, batches_labels
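
Again for reference, a minimal usage sketch, assuming the four raw MNIST IDX files have been downloaded and decompressed into ./data/mnist and that an IID split into chunks of 600 samples is wanted (path and sizes are illustrative, not part of this commit):

from datasets.loaders import mnist_dataset

dataset = mnist_dataset.load_dataset("./data/mnist")
train = dataset["train"]  # {"data": [images (60000, 784) uint8, labels (60000,) uint8]}
chunks = mnist_dataset.split_dataset(train, "iid", len(train["data"][1]), 600)
print(len(chunks), chunks[0]["images"].shape)  # 100 chunks, each (600, 784)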
