Skip to content

Commit 4f19442

Browse files
authoredJul 3, 2024··
Merge pull request #6 from RPegoud/5-fix-add_node-mutation-bug
5 fix add node mutation bug
2 parents 8809be2 + 8bd998d commit 4f19442

35 files changed

+1411
-674
lines changed
 

‎README.md

+8
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,16 @@ Mutations:
2626
* [x] Weight reset
2727
* [x] Add node
2828
* [x] Add connection
29+
* [ ] Mutate activation
2930
* [ ] Wrap all mutations in a single function
3031

32+
Misc:
33+
34+
* [ ] Add Hydra config for constant attributes
35+
* [ ] Separate ``max_nodes`` and ``max_connections``
36+
* [ ] Add bias
37+
* [ ] Set the minimum sender index to 1 instead of 0
38+
3139
Crossing:
3240

3341
* [ ] Add novelty fields to Network dataclass

‎main.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import hydra
2+
import jax
3+
from omegaconf import DictConfig, OmegaConf
4+
5+
from neat_jax import Mutations, forward, init_network, log_config
6+
7+
INPUT_SIZE = 4
8+
OUTPUT_SIZE = 3
9+
10+
11+
def run_exp(config: dict):
12+
config.input_size = INPUT_SIZE
13+
config.output_size = OUTPUT_SIZE
14+
15+
key = jax.random.key(config.params.seed)
16+
inputs = jax.random.normal(key, (config.input_size,))
17+
18+
net, activation_state = init_network(
19+
inputs, config.output_size, config.network.max_nodes, key
20+
)
21+
mutations = Mutations(max_nodes=config.network.max_nodes, **config.mutations)
22+
net = mutations.add_node(key, net)
23+
activation_state, y = forward(inputs, net, config)
24+
25+
print(activation_state, y)
26+
27+
28+
@hydra.main(
29+
config_path="neat_jax/configs",
30+
config_name="default_config.yaml",
31+
version_base="1.3.2",
32+
)
33+
def hydra_entry_point(cfg: DictConfig):
34+
OmegaConf.set_struct(cfg, False)
35+
log_config(cfg)
36+
return run_exp(cfg)
37+
38+
39+
if __name__ == "__main__":
40+
hydra_entry_point()

‎neat_jax/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
from .activation_fns import activation_fns_list, get_activation_fn
2+
from .initialization import activation_state_from_inputs, init_network
13
from .mutations import Mutations
24
from .neat_dataclasses import ActivationState, Network
35
from .nn import (
46
forward,
57
forward_toggled_nodes,
6-
get_activation,
8+
get_activation_fn,
79
get_active_connections,
810
get_required_activations,
9-
make_network,
1011
toggle_receivers,
1112
update_depth,
1213
)
13-
from .utils import plot_network, sample_from_mask
14+
from .utils import cartesian_product, log_config, plot_network, sample_from_mask
186 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
3.03 KB
Binary file not shown.
Binary file not shown.
-2.38 KB
Binary file not shown.
582 Bytes
Binary file not shown.

‎neat_jax/activation_fns.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
activation_fns_list = [
5+
lambda x: jnp.float32(x), # identity
6+
lambda x: 1 / (1 + jnp.exp(-x)), # sigmoid
7+
lambda x: jnp.divide(1, x), # inverse
8+
lambda x: jnp.sinh(x) / jnp.cosh(x), # hyperbolic cosine
9+
lambda x: jnp.float32(jnp.maximum(0, x)), # relu
10+
lambda x: jnp.float32(jnp.abs(x)), # absolute value
11+
lambda x: jnp.sin(x), # sine
12+
lambda x: jnp.exp(jnp.square(-x)), # gaussian
13+
lambda x: jnp.float32(jnp.sign(x)), # step
14+
]
15+
16+
17+
def get_activation_fn(activation_index: int, x: float) -> jnp.float32:
18+
"""
19+
Given an index, selects an activation function and computes `activation(x)`.
20+
21+
```python
22+
0: jnp.float32(x) # identity function
23+
1: 1 / (1 + jnp.exp(-x)), # sigmoid
24+
2: jnp.divide(1, x), # inverse
25+
3: jnp.sinh(x) / jnp.cosh(x), # hyperbolic cosine
26+
4: jnp.float32(jnp.maximum(0, x)), # relu
27+
5: jnp.float32(jnp.abs(x)), # absolute value
28+
6: jnp.sin(x), # sine
29+
7: jnp.exp(jnp.square(-x)), # gaussian
30+
8: jnp.float32(jnp.sign(x)), # step
31+
```
32+
"""
33+
return jax.lax.switch(
34+
activation_index,
35+
activation_fns_list,
36+
operand=x,
37+
)

‎neat_jax/configs/default_config.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
defaults:
2+
- params: base_params
3+
- network: base_net
4+
- mutations: mutations
5+
- env: gymnax/cartpole
6+
- _self_
7+
8+
- override hydra/hydra_logging: disabled
9+
- override hydra/job_logging: disabled
10+
11+
hydra:
12+
output_subdir: null
13+
run:
14+
dir: .
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
suit_name: gymnax
2+
3+
scenario:
4+
name: CartPole-v1
5+
task_name: cartpole
6+
7+
kwargs: {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
weight_shift_rate: 0.5
2+
weight_mutation_rate: 0.1
3+
add_node_rate: 0.03
4+
add_connection_rate: 0.05
5+
activation_fn_mutation_rate: 0.1
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
max_nodes: 30
2+
max_connections: 45
+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
seed: 0 # RNG seed.
2+
population_size: 1024
3+
total_generations: 1e7

‎neat_jax/initialization.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import chex
2+
import jax
3+
import jax.numpy as jnp
4+
5+
from .neat_dataclasses import ActivationState, Network
6+
from .utils.utils import cartesian_product
7+
8+
9+
def get_initial_activations(inputs: chex.Array, senders: chex.Array) -> chex.Array:
10+
"""Initializes ActivationState.values based on an array of inputs."""
11+
default_values = jnp.zeros_like(senders)
12+
keys = jnp.arange(len(inputs))
13+
matches = jnp.searchsorted(keys, senders)
14+
mask = (matches < len(keys)) & (keys[matches] == senders)
15+
return jnp.where(mask, inputs[matches], default_values)
16+
17+
18+
def create_toggled_mask(
19+
node_types: chex.Array, senders: chex.Array, input_size: int
20+
) -> chex.Array:
21+
"""Initializes the ActivationState.toggled mask based on input nodes"""
22+
input_nodes_indices = jnp.where(node_types == 0, size=input_size)[0]
23+
sender_mask = jnp.isin(senders, input_nodes_indices)
24+
return jnp.int32(sender_mask)
25+
26+
27+
def activation_state_from_inputs(
28+
inputs: chex.Array,
29+
senders: chex.Array,
30+
node_types: chex.Array,
31+
input_size: int,
32+
max_nodes: int,
33+
) -> "ActivationState":
34+
"""
35+
Resets the ActivationState in prevision of a forward pass or a depth scan.
36+
37+
Args:
38+
inputs (chex.Array): The activation values of the network's input nodes
39+
max_nodes (int): The maximum capacity of the network
40+
41+
Returns:
42+
ActivationState: The reset ActivationState with:
43+
44+
- ``values``: initialized based on inputs
45+
- ``toggled``: input neurons toggled
46+
- ``activation_counts``: set to zero
47+
- ``has_fired``: set to zero
48+
- ``outdated_depths``: True
49+
"""
50+
51+
values = get_initial_activations(inputs, senders)
52+
toggled = create_toggled_mask(node_types, senders, input_size)
53+
54+
return ActivationState(
55+
values=values,
56+
toggled=toggled,
57+
activation_counts=jnp.zeros(max_nodes, dtype=jnp.int32),
58+
has_fired=jnp.zeros(max_nodes, dtype=jnp.int32),
59+
node_depths=jnp.zeros(max_nodes, dtype=jnp.int32),
60+
outdated_depths=True,
61+
)
62+
63+
64+
def init_network(
65+
inputs: chex.Array,
66+
output_size: int,
67+
max_nodes: int,
68+
key: chex.PRNGKey,
69+
scale_weights: float = 0.1,
70+
) -> tuple[Network, ActivationState]:
71+
"""Creates a Network and ActivationState from an input array."""
72+
73+
input_size = len(inputs)
74+
n_initial_connections = input_size * output_size
75+
sender_receiver_pairs = cartesian_product(
76+
jnp.arange(input_size),
77+
jnp.arange(input_size, input_size + output_size),
78+
size=max_nodes,
79+
fill_value=-max_nodes,
80+
)
81+
senders = sender_receiver_pairs[:, 0]
82+
receivers = sender_receiver_pairs[:, 1]
83+
84+
weights_init = (
85+
jax.random.normal(key, (n_initial_connections,), dtype=jnp.float32)
86+
* scale_weights
87+
)
88+
weights = jnp.zeros(max_nodes).at[:n_initial_connections].set(weights_init)
89+
90+
node_types = jnp.concatenate(
91+
[
92+
jnp.zeros(input_size, dtype=jnp.int32), # input nodes = 0
93+
jnp.full(output_size, 2, dtype=jnp.int32), # output nodes = 2
94+
jnp.full(max_nodes - input_size - output_size, 3, dtype=jnp.int32),
95+
] # disabled nodes = 3
96+
)
97+
activation_fns = jnp.zeros(max_nodes, dtype=jnp.int32) # no activation by default
98+
99+
activations = get_initial_activations(inputs, senders)
100+
toggled_nodes = create_toggled_mask(node_types, senders, input_size)
101+
activation_counts = jnp.zeros(max_nodes, dtype=jnp.int32)
102+
activation_counts = jnp.zeros(max_nodes, dtype=jnp.int32)
103+
has_fired = jnp.zeros(max_nodes, dtype=jnp.int32)
104+
105+
return (
106+
Network(
107+
node_indices=jnp.arange(max_nodes, dtype=jnp.int32),
108+
node_types=node_types,
109+
weights=weights,
110+
activation_fns=activation_fns,
111+
senders=senders,
112+
receivers=receivers,
113+
input_size=input_size,
114+
output_size=output_size,
115+
max_nodes=max_nodes,
116+
),
117+
ActivationState(
118+
values=activations,
119+
toggled=toggled_nodes,
120+
activation_counts=activation_counts,
121+
has_fired=has_fired,
122+
node_depths=jnp.zeros(max_nodes, dtype=jnp.int32),
123+
outdated_depths=True,
124+
),
125+
)

0 commit comments

Comments
 (0)
Please sign in to comment.