
flashrl

flashrl does RL with millions of steps/second 💨 while being tiny: ~200 lines of code

🛠️ pip install flashrl or clone the repo & pip install -r requirements.txt

  • If cloned (or if envs changed), compile: python setup.py build_ext --inplace

💡 flashrl will always be tiny: Read the code (+paste into LLM) to understand it!

Quick Start 🚀

flashrl uses a Learner that holds an env and a model (default: Policy with LSTM)

import flashrl as frl

learn = frl.Learner(frl.envs.Pong(n_agents=2**14))
curves = learn.fit(40, steps=16, desc='done')
frl.print_curve(curves['loss'], label='loss')
frl.play(learn.env, learn.model, fps=8)
learn.env.close()

.fit does RL with ~10 million steps: 40 iterations × 16 steps × 2**14 agents!

Run it yourself via python train.py and play against the AI 🪄

Click here to read a tiny doc 📑

Learner takes the arguments

  • env: RL environment
  • model: A Policy model
  • device: By default picks mps or cuda if available, else cpu
  • dtype: By default torch.bfloat16 if device is cuda, else torch.float32
  • compile_no_lstm: Speeds up training via torch.compile if the model has no lstm
  • **kwargs: Passed to the Policy, e.g. hidden_size or lstm
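
For example, a minimal sketch of a customized Learner (the values below are illustrative, not defaults):

import flashrl as frl
import torch

learn = frl.Learner(frl.envs.Pong(n_agents=2**14), device='cuda',
                    dtype=torch.bfloat16, hidden_size=128)  # hidden_size is passed to the Policy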

Learner.fit takes the arguments

  • iters: Number of iterations
  • steps: Number of steps in rollout
  • desc: Progress bar description (e.g. 'reward')
  • log: If True, tensorboard logging is enabled
    • run tensorboard --logdir=runs and visit http://localhost:6006 in your browser!
  • stop_func: Function that stops training when it returns True, e.g.

def stop(kl, **kwargs):
    return kl > .1

curves = learn.fit(40, steps=16, stop_func=stop)
  • lr, anneal_lr & the args of ppo after bs: Hyperparameters (see the sketch below)
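
Putting a few of these together, a sketch of a fit call with logging and a custom learning rate (the lr value is illustrative, not a recommendation):

curves = learn.fit(40, steps=16, desc='reward', log=True, lr=1e-3)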

The most important functions in flashrl/utils.py are

  • print_curve: Visualizes the loss across the iters
  • play: Plays the environment in the terminal and takes
    • model: A Policy model
    • playable: If True, allows you to act (or decide to let the model act)
    • steps: Number of steps
    • fps: Frames per second
    • obs: Env argument that should be rendered as observations
    • dump: If True, frames are not refreshed and accumulate in the terminal
    • idx: Agent index between 0 and n_agents (default: 0)
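
For example, a sketch of playing against the trained model yourself in the terminal (the steps value is illustrative):

frl.play(learn.env, learn.model, playable=True, steps=512, fps=8)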

Environments 🕹️

Each env is one Cython (.pyx) file in flashrl/envs. That's it!

To add custom envs, use grid.pyx, pong.pyx or multigrid.pyx as a template:

  • grid.pyx for single-agent envs (~110 LOC)
  • pong.pyx for 1 vs 1 agent envs (~150 LOC)
  • multigrid.pyx for multi-agent envs (~190 LOC)
Grid                    Pong                MultiGrid
Agent must reach goal   Agent must score    Agent must reach goal first
(demo gif: grid)        (demo gif: pong)    (demo gif: multigrid)
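
Once a custom env (e.g. a hypothetical flashrl/envs/myenv.pyx copied from one of the templates) is compiled via python setup.py build_ext --inplace, it should be usable like the built-in envs:

import flashrl as frl

learn = frl.Learner(frl.envs.MyEnv(n_agents=2**14))  # MyEnv is a hypothetical custom env
curves = learn.fit(40, steps=16)
learn.env.close()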

Acknowledgements 🙌

I want to thank

and last but not least...