flashrl does RL with millions of steps/second 💨 while being tiny: ~200 lines of code
🛠️ `pip install flashrl` or clone the repo & `pip install -r requirements.txt`
- If cloned (or if envs changed), compile: `python setup.py build_ext --inplace`
💡 flashrl will always be tiny: Read the code (+ paste it into an LLM) to understand it!
flashrl uses a `Learner` that holds an `env` and a `model` (default: `Policy` with LSTM)
```python
import flashrl as frl

learn = frl.Learner(frl.envs.Pong(n_agents=2**14))
curves = learn.fit(40, steps=16, desc='done')
frl.print_curve(curves['loss'], label='loss')
frl.play(learn.env, learn.model, fps=8)
learn.env.close()
```
`.fit` does RL with ~10 million steps: 40 iterations × 16 steps × 2**14 agents!

Run it yourself via `python train.py` and play against the AI 🪄
Click here to read a tiny doc 📑
`Learner` takes the arguments
- `env`: RL environment
- `model`: A `Policy` model
- `device`: Per default picks `mps` or `cuda` if available, else `cpu`
- `dtype`: Per default `torch.bfloat16` if device is `cuda`, else `torch.float32`
- `compile_no_lstm`: Speedup via `torch.compile` if `model` has no `lstm`
- `**kwargs`: Passed to the `Policy`, e.g. `hidden_size` or `lstm` (see the sketch below)
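For illustration, a minimal sketch of constructing a `Learner` with non-default arguments; the argument names come from the list above, but the concrete values (e.g. `hidden_size=128`) are assumptions, not recommendations:

```python
import torch
import flashrl as frl

# Sketch only: argument names are documented above, the values are illustrative
learn = frl.Learner(
    frl.envs.Pong(n_agents=2**14),  # env: the RL environment
    device='cuda',                  # picked automatically (mps/cuda/cpu) if omitted
    dtype=torch.float32,            # default would be torch.bfloat16 on cuda
    hidden_size=128,                # **kwargs like this are passed to the Policy
)
```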
`Learner.fit` takes the arguments
- `iters`: Number of iterations
- `steps`: Number of steps in `rollout`
- `desc`: Progress bar description (e.g. `'reward'`)
- `log`: If `True`, `tensorboard` logging is enabled
  - run `tensorboard --logdir=runs` and visit `http://localhost:6006` in the browser!
- `stop_func`: Function that stops training if it returns `True`, e.g.
```python
def stop(kl, **kwargs):
    return kl > .1

curves = learn.fit(40, steps=16, stop_func=stop)
```
- `lr`, `anneal_lr` & args of `ppo` after `bs`: Hyperparameters (example below)
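As a usage sketch, the call below combines the documented keywords; the names (`desc`, `log`, `lr`, `anneal_lr`) come from the list above, while the values are illustrative guesses:

```python
# Sketch: keyword names are documented above, the values here are only guesses
curves = learn.fit(40, steps=16, desc='reward', log=True, lr=1e-3, anneal_lr=True)
```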
The most important functions in `flashrl/utils.py` are
- `print_curve`: Visualizes the loss across the `iters`
- `play`: Plays the environment in the terminal (example below) and takes
  - `model`: A `Policy` model
  - `playable`: If `True`, allows you to act (or decide to let the model act)
  - `steps`: Number of steps
  - `fps`: Frames per second
  - `obs`: Argument of the env that should be rendered as observations
  - `dump`: If `True`, no frame refresh -> Frames accumulate in the terminal
  - `idx`: Agent index between `0` and `n_agents` (default: `0`)
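For example, continuing the quickstart above, you might plot the loss curve and then step into the env yourself; the keyword names are the ones listed above, while the values (e.g. `steps=256`) are illustrative assumptions:

```python
# Continues the quickstart; keyword names are documented above, values are illustrative
frl.print_curve(curves['loss'], label='loss')
frl.play(learn.env, learn.model, playable=True, steps=256, fps=8, idx=0)
```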
Each env is one Cython (=`.pyx`) file in `flashrl/envs`. That's it!
To add custom envs, use `grid.pyx`, `pong.pyx` or `multigrid.pyx` as a template (see the sketch after this list):
- `grid.pyx` for single-agent envs (~110 LOC)
- `pong.pyx` for 1 vs 1 agent envs (~150 LOC)
- `multigrid.pyx` for multi-agent envs (~190 LOC)
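A rough sketch of that workflow, assuming the copied env compiles and is exposed like the built-ins (the file `myenv.pyx` and class `MyEnv` are hypothetical placeholders):

```python
# Hypothetical workflow: copy e.g. grid.pyx to myenv.pyx, adapt it, then recompile:
#   python setup.py build_ext --inplace
# Assuming the new env can then be used like the built-in ones:
import flashrl as frl

learn = frl.Learner(frl.envs.MyEnv(n_agents=2**14))  # MyEnv is a placeholder name
learn.fit(40, steps=16)
learn.env.close()
```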
| Grid | Pong | MultiGrid |
|---|---|---|
| Agent must reach goal | Agent must score | Agent must reach goal first |
I want to thank
- Joseph Suarez for open sourcing RL envs in C(ython)! Star PufferLib ⭐
- Costa Huang for open sourcing high-quality single-file RL code! Star cleanrl ⭐
and last but not least...