Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 fix add node mutation bug #6

Merged
merged 4 commits into from
Jul 3, 2024
Merged

5 fix add node mutation bug #6

merged 4 commits into from
Jul 3, 2024

Conversation

RPegoud
Copy link
Owner

@RPegoud RPegoud commented Jul 3, 2024

Corrected the initialization of network input activations at initialization and inference time:

def get_initial_activations(inputs: chex.Array, senders: chex.Array) -> chex.Array:
    """Initializes ActivationState.values based on an array of inputs."""
    default_values = jnp.zeros_like(senders)
    keys = jnp.arange(len(inputs))
    matches = jnp.searchsorted(keys, senders)
    mask = (matches < len(keys)) & (keys[matches] == senders)
    return jnp.where(mask, inputs[matches], default_values)


def create_toggled_mask(
    node_types: chex.Array, senders: chex.Array, input_size: int
) -> chex.Array:
    """Initializes the ActivationState.toggled mask based on input nodes"""
    input_nodes_indices = jnp.where(node_types == 0, size=input_size)[0]
    sender_mask = jnp.isin(senders, input_nodes_indices)
    return jnp.int32(sender_mask)


def activation_state_from_inputs(
    inputs: chex.Array,
    senders: chex.Array,
    node_types: chex.Array,
    input_size: int,
    max_nodes: int,
) -> "ActivationState":
    """
    Resets the ActivationState in prevision of a forward pass or a depth scan.

    Args:
        inputs (chex.Array): The activation values of the network's input nodes
        max_nodes (int): The maximum capacity of the network

    Returns:
        ActivationState: The reset ActivationState with:

            - ``values``: initialized based on inputs
            - ``toggled``: input neurons toggled
            - ``activation_counts``: set to zero
            - ``has_fired``: set to zero
            - ``outdated_depths``: True
    """

    values = get_initial_activations(inputs, senders)
    toggled = create_toggled_mask(node_types, senders, input_size)

    return ActivationState(
        values=values,
        toggled=toggled,
        activation_counts=jnp.zeros(max_nodes, dtype=jnp.int32),
        has_fired=jnp.zeros(max_nodes, dtype=jnp.int32),
        node_depths=jnp.zeros(max_nodes, dtype=jnp.int32),
        outdated_depths=True,
    )

@RPegoud RPegoud linked an issue Jul 3, 2024 that may be closed by this pull request
@RPegoud RPegoud merged commit 4f19442 into main Jul 3, 2024
@RPegoud RPegoud deleted the 5-fix-add_node-mutation-bug branch July 3, 2024 16:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fix add_node mutation bug
1 participant