
Wrong length of test dataset #78

Open
LPY1219 opened this issue Jan 16, 2025 · 1 comment

LPY1219 commented Jan 16, 2025

Hi,

I have set `sample_mode='enumerate'` with a batch size of 48. The two tasks are `close_jar` and `insert_onto_square_peg`, and the returned length of the `test_dataset` is 64. I believe this means that if I iterate 64 times, all samples in the test replay buffer will be enumerated. However, when I do that, I find that the `insert_onto_square_peg` task is never sampled; only after increasing the number of iterations to about 263 do I begin to see `insert_onto_square_peg` samples.

Do you have any insight into why this is happening?
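
(For context: since `__len__` in the class below returns `num_samples // batch_size`, a length of 64 at batch size 48 means `prepare_enumeration()` reported at least 64 × 48 = 3072 transitions, so 64 iterations should cover the whole buffer except at most one incomplete batch.)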

```python
class PyTorchIterableReplayDataset(IterableDataset):

    def __init__(self, replay_buffer: ReplayBuffer, sample_mode, sample_distribution_mode='transition_uniform'):
        self._replay_buffer = replay_buffer
        self._sample_mode = sample_mode
        if self._sample_mode == 'enumerate':
            self._num_samples = self._replay_buffer.prepare_enumeration()
        self._sample_distribution_mode = sample_distribution_mode

    def _generator(self):
        while True:
            if self._sample_mode == 'random':
                yield self._replay_buffer.sample_transition_batch(pack_in_dict=True, distribution_mode=self._sample_distribution_mode)
            elif self._sample_mode == 'enumerate':
                yield self._replay_buffer.enumerate_next_transition_batch(pack_in_dict=True)

    def __iter__(self):
        return iter(self._generator())

    def __len__(self):  # enumeration will throw away the last incomplete batch
        return self._num_samples // self._replay_buffer._batch_size
```
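
For reference, this is roughly the loop I use to check which tasks come back (a minimal sketch; the `'tasks'` key is my assumption, substitute whatever field `pack_in_dict=True` packs the task name under):

```python
# Sketch of the enumeration check; 'tasks' is a hypothetical key --
# replace it with the field the replay buffer actually uses.
seen = set()
it = iter(test_dataset)
for step in range(len(test_dataset)):  # 64 batches of 48 samples
    batch = next(it)
    seen.update(batch['tasks'])  # hypothetical key
print(sorted(seen))  # I expected both tasks, but only close_jar appears
```
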
LPY1219 commented Jan 16, 2025

Also, when I try to add a validation set, I find the program blocks and cannot proceed during training:

print("Start training ...", flush=True)
 i = start_epoch
 while True:
     if i == end_epoch:
         break

     print(f"Rank [{rank}], Epoch [{i}]: Training on train dataset")
     out = train(agent, train_dataset, TRAINING_ITERATIONS, rank)
     out_val= validate(agent, val_dataset, VAL_ITERATIONS, rank,)
     
     if rank == 0:
         tb.update("train", i, out)
         tb.update("val", i, out_val)

     if rank == 0 and i % 5 == 0:
         # TODO: add logic to only save some models
         save_agent(agent, f"{log_dir}/model_{i}.pth", i)
         save_agent(agent, f"{log_dir}/model_last.pth", i)
     i += 1
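
My current guess is that the ranks fall out of step around validation. This is the kind of synchronization I have in mind (a sketch only, assuming `torch.distributed` is initialized and `VAL_ITERATIONS` is identical on every rank; the barrier placement is my own assumption, not something from this repo):

```python
import torch.distributed as dist

out = train(agent, train_dataset, TRAINING_ITERATIONS, rank)
if dist.is_initialized():
    dist.barrier()  # wait until every rank finishes training (my assumption)
out_val = validate(agent, val_dataset, VAL_ITERATIONS, rank)
if dist.is_initialized():
    dist.barrier()  # keep ranks in step before rank-0 logging/saving
```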
