fixes to data generation process
ymetz committed Jul 8, 2024
1 parent efbd627 commit af9b349
Showing 4 changed files with 23 additions and 9 deletions.
docs/source/guide/add_new_experiment.rst (2 changes: 1 addition & 1 deletion)
@@ -78,7 +78,7 @@ The easiest way to generate the data is to use the ``generate_data.py`` script r
 python -m rlhfblender.generate_data --exp MyExperiment --random
 
 #! generate data for a pre-registered environment and create a new experiment (with a random policy)
-python -m rlhfblender.generate_data --env MyEnv-v0 --exp MyNewEnvironment --random -n-episodes 10
+python -m rlhfblender.generate_data --env MyEnv-v0 --exp MyNewEnvironment --random --num-episodes 10
 
 #! generate data for a pre-registered environment and use checkpoints for inference
 python -m rlhfblender.generate_data --env MyEnv-v0 --exp MyNewEnvironment --model-path path/to/checkpoints --checkpoints 100000 200000 300000
rlhfblender/data_collection/environment_handler.py (9 changes: 8 additions & 1 deletion)
@@ -41,11 +41,18 @@ def get_environment(
 
     vec_env_cls = DummyVecEnv
 
+    env_kwargs = environment_config.get("env_kwargs", None)
+    # add render_mode = 'rgb_array' to env_kwargs
+    if env_kwargs is None:
+        env_kwargs = {"render_mode": "rgb_array"}
+    else:
+        env_kwargs["render_mode"] = "rgb_array"
+
     env = make_vec_env(
         env_name,
         n_envs=n_envs,
         wrapper_class=env_wrapper,
-        env_kwargs=environment_config.get("env_kwargs", None),
+        env_kwargs=env_kwargs,
         vec_env_cls=vec_env_cls,
         vec_env_kwargs=environment_config.get("vec_env_kwargs", None),
     )
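Context for the env_kwargs change above, as a minimal sketch (illustrative only, not part of this commit): in Gymnasium, render() only returns RGB frames if render_mode was fixed when the environment was constructed, which is why the handler injects it into env_kwargs before make_vec_env builds the environments.

import gymnasium as gym

# With render_mode set at construction, render() returns an RGB frame.
env = gym.make("CartPole-v1", render_mode="rgb_array")
env.reset(seed=0)
print(env.render().shape)  # (400, 600, 3) for CartPole-v1

# Without it, render() returns None and Gymnasium emits a warning.
env = gym.make("CartPole-v1")
env.reset(seed=0)
print(env.render())  # None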
rlhfblender/routes/data.py (1 change: 1 addition)
@@ -68,6 +68,7 @@ class BenchmarkRequestModel(BenchmarkModel):
     reset_state: bool = False
     split_by_episode: bool = False
     record_episode_videos: bool = False
+
 
 
 class VideoRequestModel(BenchmarkModel):
rlhfblender/utils/data_generation.py (20 changes: 13 additions & 7 deletions)
@@ -68,13 +68,15 @@ async def run_benchmark(request: List[BenchmarkRequestModel]) -> list[Experiment
     benchmarked_experiments = []
     for benchmark_run in request:
         print(request)
-        if benchmark_run.benchmark_id != "":
+        if benchmark_run.benchmark_id != "" and await db_handler.check_if_exists(
+            database, Experiment, key=benchmark_run.benchmark_id, key_column="exp_name"
+        ):
             exp: Experiment = await db_handler.get_single_entry(
                 database, Experiment, key=benchmark_run.benchmark_id, key_column="exp_name"
             )
         else:
             # for the experiments, we need to register the environment first (e.g. for annotations, naming of the action space)
-            if not db_handler.check_if_exists(database, Environment, value=benchmark_run.env_id, column="registration_id"):
+            if not db_handler.check_if_exists(database, Environment, key=benchmark_run.env_id, key_column="registration_id"):
                 # We lazily register the environment if it is not registered yet, this is only done once
                 database_env = initial_registration(
                     benchmark_run.env_id,
@@ -83,12 +85,16 @@
                     ),
                 )
                 await db_handler.add_entry(database, Environment, database_env.model_dump())
+            else:
+                database_env = await db_handler.get_single_entry(
+                    database, Environment, key=benchmark_run.env_id, key_column="registration_id"
+                )
 
             # create and register a "dummy" experiment
             exp: Experiment = Experiment(
-                exp_name=f"{benchmark_run.env_id}_{benchmark_run.framwork}_{benchmark_run.benchmark_type}_Experiment",
+                exp_name=f"{benchmark_run.env_id}_{benchmark_run.benchmark_type}_Experiment",
                 env_id=benchmark_run.env_id,
-                framework=benchmark_run.framework,
+                framework="Random",
                 created_timestamp=int(time.time()),
             )
             await db_handler.add_entry(database, Experiment, exp)
@@ -112,7 +118,7 @@ async def run_benchmark(request: List[BenchmarkRequestModel]) -> list[Experiment
                 n_envs=1,
                 norm_env_path=os.path.join(benchmark_run.path, benchmark_run.env_id),
                 # this is how SB-Zoo does it, so we stick to it for easy cross-compatibility
-                additional_packages=benchmark_run.additional_packages if "additional_packages" in benchmark_run else [],
+                additional_packages=database_env.additional_gym_packages,
             )
             if "BabyAI" not in benchmark_run.env_id
             else gym.make(benchmark_run.env_id, render_mode="rgb_array")
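The additional_gym_packages field used above appears to mirror the SB3-Zoo convention of listing extra modules whose import registers custom environment IDs. A hypothetical sketch of that pattern (the helper name is illustrative, not code from this repository):

import importlib

import gymnasium as gym

def make_env_from_extra_packages(env_id: str, additional_gym_packages: list[str]):
    # Importing each package runs its gymnasium registration side effects,
    # so gym.make() can resolve environment IDs it would not know otherwise.
    for package in additional_gym_packages:
        importlib.import_module(package)
    return gym.make(env_id, render_mode="rgb_array")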
@@ -157,11 +163,11 @@ async def run_benchmark(request: List[BenchmarkRequestModel]) -> list[Experiment
     # Now, create the video/thumbnail/reward data etc.
     def split_data(data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """Splits the data into episodes."""
-        episode_ends = np.argwhere(data["dones"]).squeeze()
+        episode_ends = np.argwhere(data["dones"])
         episodes = {}
         for name, data_item in data.items():
             if data_item.shape:
-                episodes[name] = np.split(data_item, episode_ends)
+                episodes[name] = np.split(data_item, episode_ends.flatten() + 1)
             else:
                 episodes[name] = data_item
 
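A small worked example of the split_data behavior above (illustrative only): np.argwhere returns indices of shape (N, 1), and the removed .squeeze() would collapse a single terminal step into a 0-d array, which np.split interprets as a number of sections rather than as split indices. Flattening and adding 1 splits each array just after every done flag, so each episode keeps its terminal step.

import numpy as np

dones = np.array([False, False, True, False, True])
rewards = np.arange(5)

episode_ends = np.argwhere(dones)          # shape (2, 1): [[2], [4]]
split_points = episode_ends.flatten() + 1  # [3, 5]: split after each done
print(np.split(rewards, split_points))
# -> [array([0, 1, 2]), array([3, 4]), array([], dtype=int64)]
# Each episode now includes its terminal step; the trailing empty chunk
# only appears when the recording ends exactly on a done flag.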
