From 0af7099e204546461b7668a7300903f0e6fb2fed Mon Sep 17 00:00:00 2001 From: Holger Roth <6304754+holgerroth@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:42:45 -0500 Subject: [PATCH] Add no_grad to validation steps (#3071) * add no_grad to validation steps * small edit --- .../src/splitnn/cifar10_learner_splitnn.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/advanced/vertical_federated_learning/cifar10-splitnn/src/splitnn/cifar10_learner_splitnn.py b/examples/advanced/vertical_federated_learning/cifar10-splitnn/src/splitnn/cifar10_learner_splitnn.py index 90ab2fe48a..1e394590c2 100644 --- a/examples/advanced/vertical_federated_learning/cifar10-splitnn/src/splitnn/cifar10_learner_splitnn.py +++ b/examples/advanced/vertical_federated_learning/cifar10-splitnn/src/splitnn/cifar10_learner_splitnn.py @@ -240,11 +240,11 @@ def _train_step_data_side(self, batch_indices): def _val_step_data_side(self, batch_indices): t_start = timer() self.model.eval() + with torch.no_grad(): + inputs = self.valid_dataset.get_batch(batch_indices) + inputs = inputs.to(self.device) - inputs = self.valid_dataset.get_batch(batch_indices) - inputs = inputs.to(self.device) - - _val_activations = self.model.forward(inputs) # keep on site-1 + _val_activations = self.model.forward(inputs) # keep on site-1 self.compute_stats_pool.record_value(category="_val_step_data_side", value=timer() - t_start) @@ -295,23 +295,24 @@ def _train_step_label_side(self, batch_indices, activations, fl_ctx: FLContext): def _val_step_label_side(self, batch_indices, activations, fl_ctx: FLContext): t_start = timer() self.model.eval() + with torch.no_grad(): + labels = self.valid_dataset.get_batch(batch_indices) + labels = labels.to(self.device) - labels = self.valid_dataset.get_batch(batch_indices) - labels = labels.to(self.device) + if self.fp16: + activations = activations.type(torch.float32) # return to default pytorch precision - if self.fp16: - activations = activations.type(torch.float32) # return to default pytorch precision + activations = activations.to(self.device) - activations = activations.to(self.device) + pred = self.model.forward(activations) - pred = self.model.forward(activations) - loss = self.criterion(pred, labels) - self.val_loss.append(loss.unsqueeze(0)) # unsqueeze needed for later concatenation + loss = self.criterion(pred, labels) + self.val_loss.append(loss.unsqueeze(0)) # unsqueeze needed for later concatenation - _, pred_labels = torch.max(pred, 1) + _, pred_labels = torch.max(pred, 1) - self.val_pred_labels.extend(pred_labels.unsqueeze(0)) - self.val_labels.extend(labels.unsqueeze(0)) + self.val_pred_labels.extend(pred_labels.unsqueeze(0)) + self.val_labels.extend(labels.unsqueeze(0)) self.compute_stats_pool.record_value(category="_val_step_label_side", value=timer() - t_start)