
Enh/pl 1.x #34 (Open)

wants to merge 14 commits into master
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,4 +1,3 @@
git+https://github.com/hdmf-dev/hdmf.git@0efea00ff1d10c4021c8145e917d92566f7edb4c
seaborn==0.11.0
git+https://github.com/ajtritt/pytorch-lightning.git@99d2503373fe1b966cf7014c4ce7e7183766d48a
torch==1.6.0
git+https://github.com/ajtritt/pytorch-lightning.git@fb30942d2c47a95531e063ed35a22f8fba25be12
48 changes: 48 additions & 0 deletions src/exabiome/nn/loader.py
@@ -1,3 +1,4 @@
import os
import pytorch_lightning as pl
import torch.nn.functional as F
import torch
@@ -94,6 +95,9 @@ def dataset_stats(argv=None):


def read_dataset(path):
    for root, dirs, files in os.walk("/mnt/bb/ajtritt/"):
        for filename in files:
            print(rank, '-', filename)
    hdmfio = get_hdf5io(path, 'r')
    difile = hdmfio.read()
    dataset = SeqDataset(difile)
@@ -392,6 +396,49 @@ def get_loader(dataset, distances=False, **kwargs):
    return DataLoader(dataset, collate_fn=collater, **kwargs)


<<<<<<< HEAD
Review comment (Collaborator, Author):
This needs to be removed.

class DIDataModule(pl.LightningDataModule):

    def __init__(self, hparams, inference=False):
        self.hparams = hparams
        self.inference = inference

    def train_dataloader(self):
        self._check_loaders()
        return self.loaders['train']

    def val_dataloader(self):
        self._check_loaders()
        return self.loaders['validate']

    def test_dataloader(self):
        self._check_loaders()
        return self.loaders['test']

    def _check_loaders(self):
        """
        Load dataset if it has not been loaded yet
        """
        dataset, io = process_dataset(self.hparams, inference=self._inference)
        if self.hparams.load:
            dataset.load()

        kwargs = dict(random_state=self.hparams.seed,
                      batch_size=self.hparams.batch_size,
                      distances=self.hparams.manifold,
                      downsample=self.hparams.downsample)
        kwargs.update(self.hparams.loader_kwargs)
        if self._inference:
            kwargs['distances'] = False
            kwargs.pop('num_workers', None)
            kwargs.pop('multiprocessing_context', None)
        tr, te, va = train_test_loaders(dataset, **kwargs)
        self.loaders = {'train': tr, 'test': te, 'validate': va}
        self.dataset = dataset

=======
Review comment (Collaborator, Author):
This line needs to be removed

class DeepIndexDataModule(pl.LightningDataModule):

    def __init__(self, hparams, inference=False):
@@ -419,3 +466,4 @@ def val_dataloader(self):

    def test_dataloader(self):
        return self.loaders['test']
>>>>>>> master
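For context on resolving the conflict above: under PyTorch Lightning 1.x a LightningDataModule is expected to call super().__init__(), and the loaders can be built lazily so repeated *_dataloader() calls do not reprocess the dataset. The sketch below is illustrative only and is not the code proposed in this PR; process_dataset and train_test_loaders are the helpers used in the HEAD block above, their import path is an assumption, and only the seed and batch_size hyperparameters are carried over.

import pytorch_lightning as pl

# Assumed import path; these helpers appear in the diff above but may live elsewhere in the package.
from exabiome.nn.loader import process_dataset, train_test_loaders


class ExampleDataModule(pl.LightningDataModule):
    """Illustrative PL 1.x-style data module; not the PR's DeepIndexDataModule."""

    def __init__(self, hparams, inference=False):
        super().__init__()           # required so Lightning can initialize its own state
        self._hparams_ns = hparams   # hypothetical attribute name, kept separate from Lightning's hparams handling
        self._inference = inference
        self.loaders = None

    def _check_loaders(self):
        # Build the train/test/validate loaders once, on first use.
        if self.loaders is not None:
            return
        dataset, io = process_dataset(self._hparams_ns, inference=self._inference)
        tr, te, va = train_test_loaders(dataset,
                                        random_state=self._hparams_ns.seed,
                                        batch_size=self._hparams_ns.batch_size)
        self.loaders = {'train': tr, 'test': te, 'validate': va}

    def train_dataloader(self):
        self._check_loaders()
        return self.loaders['train']

    def val_dataloader(self):
        self._check_loaders()
        return self.loaders['validate']

    def test_dataloader(self):
        self._check_loaders()
        return self.loaders['test']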
190 changes: 116 additions & 74 deletions src/exabiome/nn/summarize.py
@@ -31,13 +31,17 @@ def read_outputs(path):
        if 'viz_emb' in f:
            ret['viz_emb'] = f['viz_emb'][:]
        ret['labels'] = f['labels'][:]

        # we won't have these three if we are looking
        # at non-representatives
        if 'train' in f:
            ret['train_mask'] = f['train'][:]
        if 'test' in f:
            ret['test_mask'] = f['test'][:]
        ret['outputs'] = f['outputs'][:]
        if 'validate' in f:
            ret['validate_mask'] = f['validate'][:]

        ret['orig_lens'] = f['orig_lens'][:]
        if 'seq_ids' in f:
            ret['seq_ids'] = f['seq_ids'][:]
@@ -66,73 +70,78 @@ def plot_results(path, tvt=True, pred=True, fig_height=7, logger=None, name=None
    labels = path['labels']
    outputs = path['outputs']

    viz_emb = None
    if 'viz_emb' in path:
        logger.info('found viz_emb')
        viz_emb = path['viz_emb']
    # else:
    #     logger.info('calculating UMAP embeddings for visualization')
    #     from umap import UMAP
    #     umap = UMAP(n_components=2)
    #     viz_emb = umap.fit_transform(outputs)
    else:
        logger.info('calculating UMAP embeddings for visualization')
        from umap import UMAP
        umap = UMAP(n_components=2)
        viz_emb = umap.fit_transform(outputs)
    n_plots = 1

    color_labels = getattr(pred, 'classes_', None)
    if color_labels is None:
        color_labels = labels
    class_pal = get_color_markers(len(np.unique(color_labels)))
    colors = np.array([class_pal[i] for i in color_labels])

    # set up figure
    fig_height = 7
    plt.figure(figsize=(n_plots*fig_height, fig_height))

    logger.info('plotting embeddings with species labels')
    # plot embeddings
    ax = plt.subplot(1, n_plots, plot_count)
    plot_seq_emb(viz_emb, labels, ax, pal=class_pal)
    if name is not None:
        plt.title(name)
    plot_count += 1

    # plot train/validation/testing data
    train_mask = None
    test_mask = None
    validate_mask = None
    if tvt:
        logger.info('plotting embeddings train/validation/test labels')
        train_mask = path['train_mask']
        test_mask = path['test_mask']
        validate_mask = path['validate_mask']
        pal = ['gray', 'red', 'yellow']
        plt.subplot(1, n_plots, plot_count)
        dsubs = ['train', 'validation', 'test']  # data subsets
        dsub_handles = list()
        for (mask, dsub, col) in zip([train_mask, validate_mask, test_mask], dsubs, pal):
            plt.scatter(viz_emb[mask, 0], viz_emb[mask, 1], s=0.1, c=[col], label=dsub)
            dsub_handles.append(Circle(0, 0, color=col))
        plt.legend(dsub_handles, dsubs)
    if viz_emb:
        logger.info('plotting embeddings with species labels')
        # plot embeddings
        ax = plt.subplot(1, n_plots, plot_count)
        plot_seq_emb(viz_emb, labels, ax, pal=class_pal)
        if name is not None:
            plt.title(name)
        plot_count += 1

        # plot train/validation/testing data
        train_mask = None
        test_mask = None
        validate_mask = None
        if tvt:
            logger.info('plotting embeddings train/validation/test labels')
            train_mask = path['train_mask']
            test_mask = path['test_mask']
            validate_mask = path['validate_mask']
            pal = ['gray', 'red', 'yellow']
            plt.subplot(1, n_plots, plot_count)
            dsubs = ['train', 'validation', 'test']  # data subsets
            dsub_handles = list()
            for (mask, dsub, col) in zip([train_mask, validate_mask, test_mask], dsubs, pal):
                plt.scatter(viz_emb[mask, 0], viz_emb[mask, 1], s=0.1, c=[col], label=dsub)
                dsub_handles.append(Circle(0, 0, color=col))
            plt.legend(dsub_handles, dsubs)
            plot_count += 1

    # run some predictions and plot report
    if pred is not False:
        if pred is None or pred is True:
            logger.info('No classifier given, using RandomForestClassifier(n_estimators=30)')
            pred = RandomForestClassifier(n_estimators=30)
        elif not (hasattr(pred, 'fit') and hasattr(pred, 'predict')):
            raise ValueError("argument 'pred' must be a classifier with an SKLearn interface")

        X_test = outputs
        y_pred = pred
        y_test = labels
        if not hasattr(pred, 'classes_'):
            train_mask = path['train_mask']
            test_mask = path['test_mask']
            X_train = outputs[train_mask]
            y_train = labels[train_mask]
            logger.info(f'training classifier {pred}')
            pred.fit(X_train, y_train)
            X_test = outputs[test_mask]
            y_test = labels[test_mask]
        logger.info(f'getting predictions')
        y_pred = pred.predict(X_test)
    if not isinstance(pred, (np.ndarray, list)):
        if pred is None or pred is True:
            logger.info('No classifier given, using RandomForestClassifier(n_estimators=30)')
            pred = RandomForestClassifier(n_estimators=30)
        elif not (hasattr(pred, 'fit') and hasattr(pred, 'predict')):
            raise ValueError("argument 'pred' must be a classifier with an SKLearn interface")

        X_test = outputs
        if not hasattr(pred, 'classes_'):
            train_mask = path['train_mask']
            test_mask = path['test_mask']
            X_train = outputs[train_mask]
            y_train = labels[train_mask]
            logger.info(f'training classifier {pred}')
            pred.fit(X_train, y_train)
            X_test = outputs[test_mask]
            y_test = labels[test_mask]
        logger.info(f'getting predictions')
        y_pred = pred.predict(X_test)

    logger.info(f'plotting classification report')
    # plot classification report
@@ -156,15 +165,15 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
    viz_emb = None
    if 'viz_emb' in path:
        viz_emb = path['viz_emb']
    else:
        viz_emb = UMAP(n_components=2).fit_transform(X)

    uniq_seqs = np.unique(seq_ids)
    X_mean = np.zeros((uniq_seqs.shape[0], outputs.shape[1]))
    X_median = np.zeros((uniq_seqs.shape[0], outputs.shape[1]))
    y = np.zeros(uniq_seqs.shape[0], dtype=int)
    seq_len = np.zeros(uniq_seqs.shape[0], dtype=int)
    seq_viz = np.zeros((uniq_seqs.shape[0], 2))
    seq_viz = None
    if viz_emb is not None:
        seq_viz = np.zeros((uniq_seqs.shape[0], 2))

    for seq_i, seq in enumerate(uniq_seqs):
        seq_mask = seq_ids == seq
@@ -174,17 +183,31 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
        y[seq_i] = uniq_labels[0]
        X_mean[seq_i] = outputs[seq_mask].mean(axis=0)
        X_median[seq_i] = np.median(outputs[seq_mask], axis=0)
        seq_viz[seq_i] = viz_emb[seq_mask].mean(axis=0)
        if seq_viz is not None:
            seq_viz[seq_i] = viz_emb[seq_mask].mean(axis=0)
        seq_len[seq_i] = olens[seq_mask].sum()

    seq_len = np.log10(seq_len)

    color_labels = getattr(clf, 'classes_', None)
    if color_labels is None:
        color_labels = labels
    class_pal = get_color_markers(len(np.unique(color_labels)))
    fig, axes = None, None
    figsize_factor = 7
    class_pal = None
    if isinstance(clf, (list, np.ndarray)):
        nrows = 2
        ncols = 1
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(nrows*figsize_factor, ncols*figsize_factor))
        axes = np.expand_dims(axes, axis=1)
        all_preds = np.argmax(outputs, axis=1)
        class_pal = get_color_markers(outputs.shape[1])
    else:
        color_labels = getattr(clf, 'classes_', None)
        if color_labels is None:
            color_labels = labels
        class_pal = get_color_markers(len(np.unique(color_labels)))

    fig, axes = plt.subplots(nrows=3, ncols=3, sharey='row', figsize=(21, 21))
        nrows = 3 if seq_viz is not None else 2
        ncols = 3
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharey='row', figsize=(nrows*figsize_factor, ncols*figsize_factor))

    # classifier from MEAN of outputs
    output_mean_preds = clf.predict(X_mean)
@@ -194,9 +217,10 @@ def aggregated_chunk_analysis(path, clf, fig_height=7):
    output_median_preds = clf.predict(X_median)
    make_plots(y, output_median_preds, axes[:,1], class_pal, seq_len, 'Median classification', seq_viz)

    # classifier from voting with chunk predictions
    all_preds = clf.predict(outputs)
    vote_preds = np.zeros_like(output_mean_preds)
    # classifier from voting with chunk predictions
    all_preds = clf.predict(outputs)

    vote_preds = np.zeros(X_mean.shape[0], dtype=int)
    for seq_i, seq in enumerate(uniq_seqs):
        seq_mask = seq_ids == seq
        vote_preds[seq_i] = stats.mode(all_preds[seq_mask])[0][0]
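For clarity, the voting step above collapses per-chunk predictions into one label per sequence by taking the most common chunk prediction. A standalone toy illustration of the same pattern (made-up arrays, nothing from the real dataset):

import numpy as np
from scipy import stats

# Toy chunk-level predictions for three sequences (ids 0, 1, 2)
seq_ids = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
all_preds = np.array([3, 3, 5, 7, 7, 2, 2, 2, 9])

uniq_seqs = np.unique(seq_ids)
vote_preds = np.zeros(uniq_seqs.shape[0], dtype=int)
for seq_i, seq in enumerate(uniq_seqs):
    seq_mask = seq_ids == seq
    # stats.mode returns (modes, counts); the [0][0] indexing assumes the older SciPy
    # behavior used in this codebase, where the mode comes back as a length-1 array.
    vote_preds[seq_i] = stats.mode(all_preds[seq_mask])[0][0]

print(vote_preds)  # -> [3 7 2]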
@@ -387,6 +411,9 @@ def summarize(argv=None):
    parser.add_argument('-A', '--aggregate_chunks', action='store_true', default=False,
                        help='aggregate chunks within sequences and perform analysis')
    parser.add_argument('-o', '--outdir', type=str, default=None, help='the output directory for figures')
    type_group = parser.add_argument_group('Problem type').add_mutually_exclusive_group()
    type_group.add_argument('-C', '--classify', action='store_true', help='run a classification problem', default=False)
    type_group.add_argument('-M', '--manifold', action='store_true', help='run a manifold learning problem', default=False)

    args = parser.parse_args(args=argv)
    if os.path.isdir(args.input):
@@ -405,23 +432,38 @@ def summarize(argv=None):
    fig_path = os.path.join(outdir, 'summary.png')
    logger = parse_logger('')

    plt.figure(figsize=(21, 7))
    pretrained = False
    if args.classifier is not None:
        with open(args.classifier, 'rb') as f:
            pred = pickle.load(f)
        pretrained = True
    else:
        pred = RandomForestClassifier(n_estimators=30)
    outputs = read_outputs(args.input)
    pred = plot_results(outputs, pred=pred, name='/'.join(args.input.split('/',)[-2:]), logger=logger)
    if args.classify:
        plt.figure(figsize=(7, 7))
        labels = outputs['labels']
        model_outputs = outputs['outputs']
        if 'test_mask' in outputs:
            mask = outputs['test_mask']
            labels = labels[mask]
            model_outputs = model_outputs[mask]

        pred = np.argmax(model_outputs, axis=1)
        class_pal = get_color_markers(model_outputs.shape[1])
        colors = np.array([class_pal[i] for i in labels])
        ax = plt.gca()
        plot_clf_report(labels, pred, ax=ax, pal=class_pal)
    else:
        plt.figure(figsize=(21, 7))
        pretrained = False
        if args.classifier is not None:
            with open(args.classifier, 'rb') as f:
                pred = pickle.load(f)
            pretrained = True
        else:
            pred = RandomForestClassifier(n_estimators=30)
        pred = plot_results(outputs, pred=pred, name='/'.join(args.input.split('/',)[-2:]), logger=logger)
        if not pretrained:
            clf_path = os.path.join(outdir, 'summary.rf.pkl')
            logger.info(f'saving classifier to {clf_path}')
            with open(clf_path, 'wb') as f:
                pickle.dump(pred, f)
    logger.info(f'saving figure to {fig_path}')
    plt.savefig(fig_path, dpi=100)
    if not pretrained:
        clf_path = os.path.join(outdir, 'summary.rf.pkl')
        logger.info(f'saving classifier to {clf_path}')
        with open(clf_path, 'wb') as f:
            pickle.dump(pred, f)

    if args.aggregate_chunks:
        logger.info(f'running summary by aggregating chunks within sequences')
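As a usage note on the new problem-type flags: summarize() accepts an argv list, so the -C/-M switches can be exercised directly from Python as well as from the command line. The snippet below is a sketch only; the positional input path, the example file names, and the --classifier flag spelling are assumptions based on args.input, args.outdir, and args.classifier in the excerpt above.

from exabiome.nn.summarize import summarize

# Classification-style summary: argmax over the stored network outputs (paths are hypothetical).
summarize(['-C', '-o', 'figures/', 'outputs.h5'])

# Default/manifold-style summary reusing a previously pickled classifier
# (flag spelling '--classifier' is an assumption; only args.classifier is visible above).
summarize(['-M', '--classifier', 'figures/summary.rf.pkl', '-o', 'figures/', 'outputs.h5'])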