Skip to content

Commit

Permalink
Minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jsschreck committed Dec 19, 2024
1 parent e553c22 commit 81a88ed
Showing 1 changed file with 1 addition and 44 deletions.
45 changes: 1 addition & 44 deletions credit/datasets/era5_multistep_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,8 @@ def __init__(
transform=self.transform,
)

self.total_length = len(self.ERA5_indices)
# Set an epoch flag so that if set_epoch is not called, a warning will be issued
self.current_epoch = None
self.current_index = None

# Use DistributedSampler for index management
self.sampler = DistributedSampler(
Expand All @@ -275,48 +274,6 @@ def __init__(
self.time_steps = None # Tracks time steps for each batch index
self.forecast_step_counts = None # Track forecast step counts for each batch item

# Initialize batch once when the dataset is created
# self.initialize_batch()

# def initialize_batch(self):
# """
# Initializes batch indices using DistributedSampler's indices.
# Resets the time steps and forecast step counts.
# Ensures proper cycling when shuffle=False.
# """
# # Initialize the call count if not already present
# if not hasattr(self, "batch_call_count"):
# self.batch_call_count = 0

# # Set epoch for DistributedSampler to ensure consistent shuffling across devices
# if self.current_epoch is not None:
# self.sampler.set_epoch(self.current_epoch)

# # Retrieve indices for this GPU
# indices = list(self.sampler)
# total_indices = len(indices)

# # Select batch indices based on call count (deterministic cycling)
# start = self.batch_call_count * self.batch_size
# end = start + self.batch_size

# if end > total_indices:
# # Wrap-around to ensure no index is skipped
# indices = indices[start:] + indices[:(end % total_indices)]
# else:
# indices = indices[start:end]

# # Increment batch_call_count, reset when all indices are cycled
# self.batch_call_count += 1
# if start + self.batch_size >= total_indices:
# self.batch_call_count = 0 # Reset for next cycle

# # Assign batch indices
# self.batch_indices = indices
# self.time_steps = [0 for _ in self.batch_indices]
# self.forecast_step_counts = [0 for _ in self.batch_indices]
# self.initial_indices = list(self.batch_indices)

def initialize_batch(self):
"""
Initializes batch indices using DistributedSampler's indices.
Expand Down

0 comments on commit 81a88ed

Please sign in to comment.