mocodad
Bases: LightningModule
Build the model according to the specified hyperparameters. If the conditioning strategy is ‘inject’, a conditioning network is also built; the available architectures are AutoEncoder (AE), Encoder (E) and Encoder-UNet (E_unet). For all other conditioning strategies, the conditioning network is set to None.
- Raises: NotImplementedError – if the conditioning architecture is not implemented
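The branching described above can be sketched in plain Python. The placeholder classes `AE`, `E` and `EUnet`, the registry and the `build_conditioning_net` helper are hypothetical stand-ins, not MoCoDAD's actual modules.

```python
from typing import Optional


# Hypothetical placeholders for the real conditioning networks.
class AE: ...
class E: ...
class EUnet: ...


# Hypothetical registry mapping architecture codes to classes.
CONDITIONING_ARCHITECTURES = {"AE": AE, "E": E, "E_unet": EUnet}


def build_conditioning_net(conditioning_strategy: str, conditioning_architecture: str) -> Optional[object]:
    """Return a conditioning network for the 'inject' strategy, otherwise None."""
    if conditioning_strategy != "inject":
        # Other strategies do not use a separate conditioning network.
        return None
    try:
        arch_cls = CONDITIONING_ARCHITECTURES[conditioning_architecture]
    except KeyError:
        raise NotImplementedError(
            f"Conditioning architecture {conditioning_architecture!r} is not implemented"
        ) from None
    return arch_cls()
```

Keeping the architectures in a dictionary means an unknown code fails in one place with the documented NotImplementedError.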
conditioning_strategies = {'add2layers': 'inject', 'cat': 'concat', 'concat': 'concat', 'inbetween_imp': 'interleave', 'inject': 'inject', 'interleave': 'interleave', 'no_condition': 'no_condition', 'none': 'no_condition', 'random_imp': 'random_imp', 'random_indices': 'random_imp'}
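The dictionary maps user-facing aliases to canonical strategy names, so several spellings select the same behaviour (for example, ‘cat’ and ‘concat’ both resolve to ‘concat’, and ‘add2layers’ resolves to ‘inject’). A minimal sketch of that lookup follows; the `resolve_strategy` helper and its choice of ValueError for unknown names are assumptions for illustration.

```python
# Copied from the attribute documented above.
CONDITIONING_STRATEGIES = {
    "add2layers": "inject", "cat": "concat", "concat": "concat",
    "inbetween_imp": "interleave", "inject": "inject", "interleave": "interleave",
    "no_condition": "no_condition", "none": "no_condition",
    "random_imp": "random_imp", "random_indices": "random_imp",
}


def resolve_strategy(name: str) -> str:
    """Map an alias to its canonical conditioning strategy (hypothetical helper)."""
    if name not in CONDITIONING_STRATEGIES:
        raise ValueError(
            f"Unknown conditioning strategy {name!r}; expected one of {sorted(CONDITIONING_STRATEGIES)}"
        )
    return CONDITIONING_STRATEGIES[name]
```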
Configure the optimizers and the learning rate schedulers.
- Returns: dictionary containing the optimizers, the learning rate schedulers and the metric to monitor
- Return type: Dict
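Lightning accepts a dictionary with an "optimizer" key and an "lr_scheduler" entry whose config can name the metric to monitor. The sketch below shows that layout; the Adam optimizer, the ReduceLROnPlateau schedule, the learning rate and the metric name "val_loss" are all hypothetical choices, not MoCoDAD's settings.

```python
import torch
from torch import nn


def configure_optimizers_sketch(model: nn.Module, lr: float = 1e-4, monitor: str = "val_loss") -> dict:
    """Build the optimizer/scheduler dictionary in the layout Lightning accepts."""
    # Hypothetical choices: Adam, and halving the LR when the monitored metric plateaus.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
    return {
        "optimizer": optimizer,
        # "monitor" names the logged metric the plateau scheduler watches.
        "lr_scheduler": {"scheduler": scheduler, "monitor": monitor},
    }
```

A plateau scheduler is used here because it is the case where the monitored metric actually matters; a fixed step schedule would ignore it.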
Forward pass of the model.
Parameters:
- input_data (List[torch.Tensor]) – list containing the following tensors:
  - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
  - transformation_idx
  - metadata
  - actual_frames
- aggr_strategy (str, optional) – aggregation strategy to use. If not specified as a function parameter, the aggregation strategy specified in the model hyperparameters is used. Defaults to None.
- return (str, optional) – return value of the model:
  - only the selected poses according to the aggregation strategy (‘pose’)
  - only the loss of the selected poses (‘loss’)
  - both (‘all’)
  If not specified as a function parameter, the return value specified in the model hyperparameters is used. Defaults to None.
- Returns: [predicted poses and the loss, tensor_data, transformation_idx, metadata, actual_frames]
- Return type: List[torch.Tensor]
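One way to picture the aggregation and return options: if several candidate poses are produced per input sequence, each with its own loss, the aggregation strategy picks or combines them per sequence. The sketch assumes candidates are stacked along a leading axis of size n_gen, and the strategy names ‘best’, ‘mean’ and ‘median’ are hypothetical, not necessarily the values MoCoDAD accepts. Because `return` is a Python keyword, the sketch names that parameter `return_`.

```python
import numpy as np


def aggregate(poses: np.ndarray, losses: np.ndarray, aggr_strategy: str = "best", return_: str = "all"):
    """Aggregate n_gen candidates per sequence.

    poses:  (n_gen, B, C, T, V) candidate poses (assumed layout)
    losses: (n_gen, B) loss of each candidate
    """
    if aggr_strategy == "best":
        idx = losses.argmin(axis=0)            # (B,) lowest-loss candidate per sequence
        batch = np.arange(losses.shape[1])     # (B,) sequence indices
        sel_poses, sel_loss = poses[idx, batch], losses[idx, batch]
    elif aggr_strategy == "mean":
        sel_poses, sel_loss = poses.mean(axis=0), losses.mean(axis=0)
    elif aggr_strategy == "median":
        sel_poses, sel_loss = np.median(poses, axis=0), np.median(losses, axis=0)
    else:
        raise ValueError(f"Unknown aggregation strategy {aggr_strategy!r}")

    # Mirror the 'pose' / 'loss' / 'all' options described above.
    if return_ == "pose":
        return sel_poses
    if return_ == "loss":
        return sel_loss
    if return_ == "all":
        return sel_poses, sel_loss
    raise ValueError(f"Unknown return value {return_!r}")
```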
losses = {'l1': <class 'torch.nn.modules.loss.L1Loss'>, 'l2': <class 'torch.nn.modules.loss.MSELoss'>, 'smooth_l1': <class 'torch.nn.modules.loss.SmoothL1Loss'>}
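For intuition, the three entries correspond to the following element-wise penalties, as computed by the PyTorch classes with `reduction='none'` and, for SmoothL1Loss, the default `beta=1.0`. The NumPy functions below are illustrative re-implementations, not the PyTorch code itself.

```python
import numpy as np


def l1(pred, target):
    # Absolute error, as in torch.nn.L1Loss(reduction="none").
    return np.abs(pred - target)


def l2(pred, target):
    # Squared error, as in torch.nn.MSELoss(reduction="none").
    return (pred - target) ** 2


def smooth_l1(pred, target, beta=1.0):
    # Quadratic below beta, linear above, as in torch.nn.SmoothL1Loss(reduction="none").
    diff = np.abs(pred - target)
    return np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)


LOSSES = {"l1": l1, "l2": l2, "smooth_l1": smooth_l1}
```

With errors of 0.5 and 2.0, L1 gives 0.5 and 2.0, L2 gives 0.25 and 4.0, and smooth L1 gives 0.125 and 1.5: it behaves like L2 for small errors and like L1 for large ones.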
Test epoch end of the model.
- Returns: test auc score
- Return type: float
Called when the test epoch begins.
Validation epoch end of the model.
- Returns: validation auc score
- Return type: float
Called when the validation epoch begins.
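The epoch hooks typically bracket the step methods: the start hook clears a per-epoch buffer, each step appends its outputs, and the end hook reduces the buffer to a single score such as the AUC. In Lightning 2.x the epoch-end hooks no longer receive the step outputs as an argument, so models usually keep such a buffer themselves. The framework-free `EpochBuffer` class below is a hypothetical illustration of that life cycle.

```python
from typing import Callable, List, Sequence


class EpochBuffer:
    """Hypothetical per-epoch store for step outputs."""

    def __init__(self, score_fn: Callable[[List[Sequence]], float]):
        self.score_fn = score_fn
        self.outputs: List[Sequence] = []

    def on_epoch_start(self) -> None:
        # Drop outputs left over from a previous epoch.
        self.outputs.clear()

    def on_step(self, step_output: Sequence) -> None:
        # Each step contributes one list of tensors.
        self.outputs.append(step_output)

    def on_epoch_end(self) -> float:
        # Reduce all collected outputs to one epoch-level score.
        return self.score_fn(self.outputs)
```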
Post processing of the model.
Parameters:
- out (np.ndarray) – output of the model
- gt_data (np.ndarray) – ground truth data
- trans (np.ndarray) – transformation index
- meta (np.ndarray) – metadata
- frames (np.ndarray) – frame indexes of the data
- Returns: auc score
- Return type: float
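How the model output is turned into anomaly scores is not detailed here. Assuming the post-processing yields one score per frame and binary ground-truth labels (1 = anomalous), the AUC itself can be computed as the probability that a randomly chosen anomalous frame scores higher than a randomly chosen normal one, with ties counting as half. The `roc_auc` helper is a hypothetical, quadratic-time illustration; for the same inputs it should match `sklearn.metrics.roc_auc_score`.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC via pairwise comparison (Mann-Whitney U statistic, ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

For labels [0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8], three of the four positive/negative pairs are ranked correctly, giving an AUC of 0.75.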
Skip the prediction step and test the model on the saved tensors.
- Parameters: split_name (str) – split name (val, test)
- Returns: auc score
- Return type: float
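Skipping prediction only works if the outputs of an earlier run were written to disk. A minimal sketch of such a save/load round trip follows, assuming one NumPy file per element of the output list, named after the split. The file layout and the `save_outputs`/`load_outputs` helpers are hypothetical; the element names follow the post-processing parameters above.

```python
import os

import numpy as np

# Order of the saved list, matching the post-processing parameters (assumed).
NAMES = ("out", "gt_data", "trans", "meta", "frames")


def save_outputs(directory: str, split_name: str, arrays) -> None:
    # Hypothetical layout: one .npy file per array, prefixed by the split name.
    for name, arr in zip(NAMES, arrays):
        np.save(os.path.join(directory, f"{split_name}_{name}.npy"), arr)


def load_outputs(directory: str, split_name: str):
    # Numeric arrays load directly; object arrays would need allow_pickle=True.
    return [np.load(os.path.join(directory, f"{split_name}_{name}.npy")) for name in NAMES]
```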
Test step of the model. It saves the output of the model and the input data as List[torch.Tensor]: [predicted poses and the loss, tensor_data, transformation_idx, metadata, actual_frames]
Parameters:
- batch (List[torch.Tensor]) – list containing the following tensors:
  - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
  - transformation_idx
  - metadata
  - actual_frames
- batch_idx (int) – index of the batch
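The saved list has a fixed order that lines up with the parameters of the post-processing method. A minimal sketch of that packing, assuming `model_output` already bundles the predicted poses and their loss; the `pack_step_output` helper and the shape check are hypothetical additions for illustration.

```python
import numpy as np


def pack_step_output(model_output, batch):
    """Return [model_output, tensor_data, transformation_idx, metadata, actual_frames]."""
    # Unpack the four-element batch described above.
    tensor_data, transformation_idx, metadata, actual_frames = batch
    if np.ndim(tensor_data) != 4:
        raise ValueError(f"tensor_data must have shape (B, C, T, V), got {np.shape(tensor_data)}")
    return [model_output, tensor_data, transformation_idx, metadata, actual_frames]
```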
Training step of the model.
Parameters:
- batch (List[torch.Tensor]) – list containing the following tensors:
  - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
  - transformation_idx
  - metadata
  - actual_frames
- batch_idx (int) – index of the batch
- Returns: loss of the model
- Return type: torch.float32
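The training objective itself is not documented here. The sketch below uses a plain reconstruction MSE purely as a stand-in, to show the batch unpacking and the scalar float32 loss the step returns; the `training_step_sketch` function and the `reconstruct` callable are hypothetical.

```python
import numpy as np


def training_step_sketch(batch, reconstruct):
    """Hypothetical step: reconstruct the input and return a scalar MSE loss.

    `reconstruct` stands in for the real network, mapping (B, C, T, V) -> (B, C, T, V).
    """
    tensor_data, _transformation_idx, _metadata, _actual_frames = batch
    prediction = reconstruct(tensor_data)
    loss = np.mean((prediction - tensor_data) ** 2)  # mean over all elements -> scalar
    return np.float32(loss)
```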
Validation step of the model. It saves the output of the model and the input data as List[torch.Tensor]: [predicted poses and the loss, tensor_data, transformation_idx, metadata, actual_frames]
Parameters:
- batch (List[torch.Tensor]) – list containing the following tensors:
  - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
  - transformation_idx
  - metadata
  - actual_frames
- batch_idx (int) – index of the batch
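Before the epoch-level AUC can be computed, the per-batch lists saved by the step have to be merged. A minimal sketch, assuming every element of each saved list (including the model output) is an array whose first axis is the batch; the `concat_epoch_outputs` helper is hypothetical.

```python
import numpy as np


def concat_epoch_outputs(step_outputs):
    """Merge per-step lists [out, tensor_data, trans, meta, frames] along the batch axis."""
    n_fields = len(step_outputs[0])
    return [np.concatenate([step[i] for step in step_outputs], axis=0) for i in range(n_fields)]
```

If the model output is instead a (poses, loss) pair, each part would need to be concatenated separately.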