Commit 74d4c70: Merge branch 'main' of github.com:google/evojax into main
alantian committed Jul 22, 2022 · 2 parents 24662e0 + 81d709d

Showing 19 changed files with 389 additions and 51 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@ EvoJAX is a scalable, general purpose, hardware-accelerated [neuroevolution](htt
 
 This repo also includes several extensible examples of EvoJAX for a wide range of tasks, including supervised learning, reinforcement learning and generative art, demonstrating how EvoJAX can run your evolution experiments within minutes on a single accelerator, compared to hours or days when using CPUs.
 
-EvoJAX paper: https://arxiv.org/abs/2202.05008
+EvoJAX paper: https://arxiv.org/abs/2202.05008 (presentation [video](https://youtu.be/TMkft3wWpb8))
 
 Please use this BibTeX if you wish to cite this project in your publications:

evojax/algo/__init__.py (3 changes: 3 additions & 0 deletions)

@@ -23,6 +23,7 @@
 from .sep_cma_es import Sep_CMA_ES
 from .cma_jax import CMA_ES_JAX
 from .map_elites import MAPElites
+from .iamalgam import iAMaLGaM
 
 Strategies = {
     "CMA": CMA,
@@ -34,6 +35,7 @@
     "Sep_CMA_ES": Sep_CMA_ES,
     "CMA_ES_JAX": CMA_ES_JAX,
     "MAPElites": MAPElites,
+    "iAMaLGaM": iAMaLGaM,
 }
 
 __all__ = [
@@ -48,5 +50,6 @@
     "Sep_CMA_ES",
     "CMA_ES_JAX",
     "MAPElites",
+    "iAMaLGaM",
     "Strategies",
 ]
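
Registering the class in both Strategies and __all__ lets training scripts resolve the solver by its string name. A minimal lookup sketch (the constructor arguments here are illustrative only):

    from evojax.algo import Strategies

    solver = Strategies["iAMaLGaM"](param_size=4, pop_size=16, seed=0)
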
evojax/algo/ars.py (28 changes: 17 additions & 11 deletions)

@@ -93,14 +93,18 @@ def __init__(
         )
 
         # Set hyperparameters according to provided inputs
-        self.es_params = self.es.default_params
-        for k, v in optimizer_config.items():
-            self.es_params[k] = v
-        self.es_params["sigma_init"] = init_stdev
-        self.es_params["sigma_decay"] = decay_stdev
-        self.es_params["sigma_limit"] = limit_stdev
-        self.es_params["init_min"] = 0.0
-        self.es_params["init_max"] = 0.0
+        self.es_params = self.es.default_params.replace(
+            sigma_init=init_stdev,
+            sigma_decay=decay_stdev,
+            sigma_limit=limit_stdev,
+            init_min=0.0,
+            init_max=0.0,
+        )
+
+        # Update optimizer-specific parameters of Adam
+        self.es_params = self.es_params.replace(
+            opt_params=self.es_params.opt_params.replace(**optimizer_config)
+        )
 
         # Initialize the evolution strategy state
         self.rand_key, init_key = jax.random.split(self.rand_key)
@@ -126,9 +130,11 @@ def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
 
     @property
     def best_params(self) -> jnp.ndarray:
-        return jnp.array(self.es_state["mean"], copy=True)
+        return jnp.array(self.es_state.mean, copy=True)
 
     @best_params.setter
     def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
-        self.es_state["best_member"] = jnp.array(params, copy=True)
-        self.es_state["mean"] = jnp.array(params, copy=True)
+        self.es_state = self.es_state.replace(
+            best_member=jnp.array(params, copy=True),
+            mean=jnp.array(params, copy=True),
+        )
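
These changes track evosax's switch from dict-based parameter and state containers to immutable dataclasses, which cannot be mutated item-by-item and are instead updated functionally with .replace(); note the nested .replace() for Adam's opt_params. A minimal sketch of the idiom using stdlib frozen dataclasses (the field names here are illustrative, not evosax's actual schema):

    from dataclasses import dataclass, field, replace

    @dataclass(frozen=True)
    class OptParams:
        lrate_init: float = 0.01  # illustrative field

    @dataclass(frozen=True)
    class EsParams:
        sigma_init: float = 0.03
        sigma_decay: float = 1.0
        opt_params: OptParams = field(default_factory=OptParams)

    params = EsParams()
    # params.sigma_init = 0.1  # would raise FrozenInstanceError
    params = replace(params, sigma_init=0.1, sigma_decay=0.999)  # returns a new object
    # Nested update: rebuild the inner dataclass, then the outer one.
    params = replace(params, opt_params=replace(params.opt_params, lrate_init=0.001))
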
evojax/algo/cma_evosax.py (11 changes: 6 additions & 5 deletions)

@@ -72,8 +72,7 @@ def __init__(
         )
 
         # Set hyperparameters according to provided inputs
-        self.es_params = self.es.default_params
-        self.es_params["sigma_init"] = init_stdev
+        self.es_params = self.es.default_params.replace(sigma_init=init_stdev)
 
         # Initialize the evolution strategy state
         self.rand_key, init_key = jax.random.split(self.rand_key)
@@ -99,9 +98,11 @@ def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
 
     @property
     def best_params(self) -> jnp.ndarray:
-        return jnp.array(self.es_state["mean"], copy=True)
+        return jnp.array(self.es_state.mean, copy=True)
 
     @best_params.setter
     def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
-        self.es_state["best_member"] = jnp.array(params, copy=True)
-        self.es_state["mean"] = jnp.array(params, copy=True)
+        self.es_state = self.es_state.replace(
+            best_member=jnp.array(params, copy=True),
+            mean=jnp.array(params, copy=True),
+        )
evojax/algo/iamalgam.py (135 changes: 135 additions & 0 deletions)

@@ -0,0 +1,135 @@
+import sys
+
+import logging
+from typing import Union, Optional
+import numpy as np
+import jax
+import jax.numpy as jnp
+
+from evojax.algo.base import NEAlgorithm
+from evojax.util import create_logger
+
+
+class iAMaLGaM(NEAlgorithm):
+    """A wrapper around evosax's iAMaLGaM.
+    Implementation: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/indep_iamalgam.py
+    Reference: Bosman et al. (2013) - https://tinyurl.com/y9fcccx2
+    """
+
+    def __init__(
+        self,
+        param_size: int,
+        pop_size: int,
+        elite_ratio: float = 0.35,
+        full_covariance: bool = False,
+        eta_sigma: Optional[float] = None,
+        eta_shift: Optional[float] = None,
+        init_stdev: float = 0.01,
+        decay_stdev: float = 0.999,
+        limit_stdev: float = 0.001,
+        w_decay: float = 0.0,
+        seed: int = 0,
+        logger: logging.Logger = None,
+    ):
+        """Initialization function.
+        Args:
+            param_size - Parameter size.
+            pop_size - Population size.
+            elite_ratio - Population elite fraction used for mean update.
+            full_covariance - Whether to estimate full covariance or only diag.
+            eta_sigma - Lrate for covariance (use default if not provided).
+            eta_shift - Lrate for mean shift (use default if not provided).
+            init_stdev - Initial scale of Gaussian perturbation.
+            decay_stdev - Multiplicative scale decay between tell iterations.
+            limit_stdev - Smallest scale (clipping limit).
+            w_decay - L2 weight regularization coefficient.
+            seed - Random seed for parameter sampling.
+            logger - Logger.
+        """
+
+        # Delayed importing of evosax
+
+        if sys.version_info.minor < 7:
+            print(
+                "evosax, which is needed by iAMaLGaM, requires"
+                " python>=3.7"
+            )
+            print(" please consider upgrading your Python version.")
+            sys.exit(1)
+
+        try:
+            import evosax
+        except ModuleNotFoundError:
+            print("You need to install evosax for its iAMaLGaM:")
+            print("  pip install evosax")
+            sys.exit(1)
+
+        # Set up object variables.
+
+        if logger is None:
+            self.logger = create_logger(name="iAMaLGaM")
+        else:
+            self.logger = logger
+
+        self.param_size = param_size
+        self.pop_size = abs(pop_size)
+        self.rand_key = jax.random.PRNGKey(seed=seed)
+
+        # Instantiate evosax's iAMaLGaM - choice between full cov & diagonal
+        if full_covariance:
+            self.es = evosax.Full_iAMaLGaM(
+                popsize=pop_size, num_dims=param_size, elite_ratio=elite_ratio
+            )
+        else:
+            self.es = evosax.Indep_iAMaLGaM(
+                popsize=pop_size, num_dims=param_size, elite_ratio=elite_ratio
+            )
+
+        # Set hyperparameters according to provided inputs
+        self.es_params = self.es.default_params.replace(
+            sigma_init=init_stdev,
+            sigma_decay=decay_stdev,
+            sigma_limit=limit_stdev,
+            init_min=0.0,
+            init_max=0.0,
+        )
+
+        # Only replace learning rates for mean shift and sigma if provided!
+        if eta_shift is not None:
+            self.es_params = self.es_params.replace(eta_shift=eta_shift)
+        if eta_sigma is not None:
+            self.es_params = self.es_params.replace(eta_sigma=eta_sigma)
+
+        # Initialize the evolution strategy state
+        self.rand_key, init_key = jax.random.split(self.rand_key)
+        self.es_state = self.es.initialize(init_key, self.es_params)
+
+        # By default evojax assumes maximization of fitness score!
+        # Evosax, on the other hand, minimizes!
+        self.fit_shaper = evosax.FitnessShaper(w_decay=w_decay, maximize=True)
+
+    def ask(self) -> jnp.ndarray:
+        self.rand_key, ask_key = jax.random.split(self.rand_key)
+        self.params, self.es_state = self.es.ask(
+            ask_key, self.es_state, self.es_params
+        )
+        return self.params
+
+    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
+        # Reshape fitness to conform with evosax minimization
+        fit_re = self.fit_shaper.apply(self.params, fitness)
+        self.es_state = self.es.tell(
+            self.params, fit_re, self.es_state, self.es_params
+        )
+
+    @property
+    def best_params(self) -> jnp.ndarray:
+        return jnp.array(self.es_state.mean, copy=True)
+
+    @best_params.setter
+    def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
+        self.es_state = self.es_state.replace(
+            best_member=jnp.array(params, copy=True),
+            mean=jnp.array(params, copy=True),
+        )
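
A minimal usage sketch of the new wrapper (assuming evosax is installed; the toy quadratic objective and hyperparameter values are illustrative only):

    import jax.numpy as jnp
    from evojax.algo import iAMaLGaM

    solver = iAMaLGaM(param_size=8, pop_size=64, seed=0)
    for _ in range(100):
        params = solver.ask()                    # shape: (pop_size, param_size)
        fitness = -jnp.sum(params ** 2, axis=1)  # EvoJAX maximizes fitness
        solver.tell(fitness)
    print(solver.best_params)  # the strategy's current mean
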
evojax/algo/open_es.py (29 changes: 18 additions & 11 deletions)

@@ -87,14 +87,19 @@ def __init__(
         )
 
         # Set hyperparameters according to provided inputs
-        self.es_params = self.es.default_params
-        for k, v in optimizer_config.items():
-            self.es_params[k] = v
-        self.es_params["sigma_init"] = init_stdev
-        self.es_params["sigma_decay"] = decay_stdev
-        self.es_params["sigma_limit"] = limit_stdev
-        self.es_params["init_min"] = 0.0
-        self.es_params["init_max"] = 0.0
+        # Set hyperparameters according to provided inputs
+        self.es_params = self.es.default_params.replace(
+            sigma_init=init_stdev,
+            sigma_decay=decay_stdev,
+            sigma_limit=limit_stdev,
+            init_min=0.0,
+            init_max=0.0,
+        )
+
+        # Update optimizer-specific parameters of Adam
+        self.es_params = self.es_params.replace(
+            opt_params=self.es_params.opt_params.replace(**optimizer_config)
+        )
 
         # Initialize the evolution strategy state
         self.rand_key, init_key = jax.random.split(self.rand_key)
@@ -122,9 +127,11 @@ def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
 
     @property
     def best_params(self) -> jnp.ndarray:
-        return jnp.array(self.es_state["mean"], copy=True)
+        return jnp.array(self.es_state.mean, copy=True)
 
     @best_params.setter
     def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
-        self.es_state["best_member"] = jnp.array(params, copy=True)
-        self.es_state["mean"] = jnp.array(params, copy=True)
+        self.es_state = self.es_state.replace(
+            best_member=jnp.array(params, copy=True),
+            mean=jnp.array(params, copy=True),
+        )
evojax/algo/sep_cma_es.py (11 changes: 6 additions & 5 deletions)

@@ -72,8 +72,7 @@ def __init__(
         )
 
        # Set hyperparameters according to provided inputs
-        self.es_params = self.es.default_params
-        self.es_params["sigma_init"] = init_stdev
+        self.es_params = self.es.default_params.replace(sigma_init=init_stdev)
 
         # Initialize the evolution strategy state
         self.rand_key, init_key = jax.random.split(self.rand_key)
@@ -99,9 +98,11 @@ def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
 
     @property
     def best_params(self) -> jnp.ndarray:
-        return jnp.array(self.es_state["mean"], copy=True)
+        return jnp.array(self.es_state.mean, copy=True)
 
     @best_params.setter
     def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
-        self.es_state["best_member"] = jnp.array(params, copy=True)
-        self.es_state["mean"] = jnp.array(params, copy=True)
+        self.es_state = self.es_state.replace(
+            best_member=jnp.array(params, copy=True),
+            mean=jnp.array(params, copy=True),
+        )
evojax/sim_mgr.py (7 changes: 4 additions & 3 deletions)

@@ -376,8 +376,9 @@ def _scan_loop_eval(self,
         scores = jnp.mean(scores.ravel().reshape((-1, n_repeats)), axis=-1)
 
         # Note: QD methods do not support ma_training for now.
-        final_states = jax.tree_map(
-            lambda x: x.reshape((scores.shape[0], n_repeats, *x.shape[1:])),
-            final_states)
+        if not self._ma_training:
+            final_states = jax.tree_map(
+                lambda x: x.reshape((scores.shape[0], n_repeats, *x.shape[1:])),
+                final_states)
 
         return scores, self._bd_summarize_fn(final_states)
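
To make the guarded reshape concrete, here is a small standalone sketch of the repeat-averaging logic (toy shapes only: 3 test cases, n_repeats=2, and a hypothetical "obs" leaf in the state pytree):

    import jax
    import jax.numpy as jnp

    n_repeats = 2
    scores = jnp.arange(6.0)  # 3 test cases x 2 repeats, flattened
    mean_scores = jnp.mean(scores.ravel().reshape((-1, n_repeats)), axis=-1)
    # mean_scores -> [0.5, 2.5, 4.5]

    # Each leaf of final_states is regrouped the same way (skipped under ma_training):
    final_states = {"obs": jnp.zeros((6, 4))}
    final_states = jax.tree_map(
        lambda x: x.reshape((mean_scores.shape[0], n_repeats, *x.shape[1:])),
        final_states)
    # final_states["obs"].shape -> (3, 2, 4)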