Skip to content

Commit

Permalink
Merge pull request #255 from AllenNeuralDynamics:tests+docs-treadmill
Browse files Browse the repository at this point in the history
Add bonsai deserialization tests
  • Loading branch information
bruno-f-cruz authored Jun 20, 2024
2 parents 6b350aa + 35cfb63 commit 64cce52
Show file tree
Hide file tree
Showing 16 changed files with 412 additions and 229 deletions.
93 changes: 64 additions & 29 deletions examples/examples.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import os

import aind_behavior_services.rig as rig
import aind_behavior_services.task_logic.distributions as distributions
Expand All @@ -16,19 +17,25 @@
Measurement,
WaterValveCalibration,
WaterValveCalibrationInput,
WaterValveCalibrationOutput,
)
from aind_behavior_services.session import AindBehaviorSessionModel
from aind_behavior_vr_foraging.rig import AindVrForagingRig, RigCalibration, Treadmill
from aind_behavior_vr_foraging.rig import (
AindVrForagingRig,
HarpTreadmill,
RigCalibration,
Treadmill,
)
from aind_behavior_vr_foraging.task_logic import (
AindVrForagingTaskLogic,
AindVrForagingTaskParameters,
)


def main():
# Create a new Session instance

example_session = AindBehaviorSessionModel(
def mock_session() -> AindBehaviorSessionModel:
"""Generates a mock AindBehaviorSessionModel model"""
return AindBehaviorSessionModel(
date=datetime.datetime.now(tz=datetime.timezone.utc),
experiment="AindVrForaging",
root_path="c://",
remote_path="c://remote",
Expand All @@ -40,10 +47,9 @@ def main():
experimenter=["Foo", "Bar"],
)

# Create a new Rig instance

# Create calibrations

def mock_rig() -> AindVrForagingRig:
"""Generates a mock AindVrForagingRig model"""
olfactometer_calibration = OlfactometerCalibration(
output=OlfactometerCalibrationOutput(),
date=datetime.datetime.now(),
Expand Down Expand Up @@ -84,15 +90,16 @@ def main():

water_valve_input = WaterValveCalibrationInput(
measurements=[
Measurement(valve_open_interval=0.2, valve_open_time=0.1, water_weight=[0.1, 0.1], repeat_count=200),
Measurement(valve_open_interval=0.2, valve_open_time=1.0, water_weight=[1, 1], repeat_count=200),
Measurement(valve_open_interval=1, valve_open_time=1, water_weight=[1, 1], repeat_count=200),
Measurement(valve_open_interval=2, valve_open_time=2, water_weight=[2, 2], repeat_count=200),
]
)
water_valve_calibration = WaterValveCalibration(
input=water_valve_input, output=water_valve_input.calibrate_output(), calibration_date=datetime.datetime.now()
)
water_valve_calibration.output = WaterValveCalibrationOutput(slope=1, offset=0) # For testing purposes

example_rig = AindVrForagingRig(
return AindVrForagingRig(
rig_name="test_rig",
triggered_camera_controller=rig.CameraController[rig.SpinnakerCamera](
frame_rate=120,
Expand All @@ -108,15 +115,21 @@ def main():
harp_clock_generator=rig.HarpClockGenerator(port_name="COM6"),
harp_analog_input=None,
harp_sniff_detector=rig.HarpSniffDetector(port_name="COM7"),
treadmill=Treadmill(
harp_board=rig.HarpTreadmill(port_name="COM8"),
settings=rig.Treadmill(wheel_diameter=15, pulses_per_revolution=28800),
harp_treadmill=HarpTreadmill(
port_name="COM8",
calibration=Treadmill(
wheel_diameter=15,
pulses_per_revolution=28800,
break_lookup_calibration=[[0, 0], [1, 65535]],
),
),
screen=rig.Screen(display_index=1),
calibration=RigCalibration(water_valve=water_valve_calibration, olfactometer=olfactometer_calibration),
)

# Create a new TaskLogic instance

def mock_task_logic() -> AindVrForagingTaskLogic:
"""Generates a mock AindVrForagingTaskLogic model"""

def NumericalUpdaterParametersHelper(initial_value, increment, decrement, minimum, maximum):
return vr_task_logic.NumericalUpdaterParameters(
Expand Down Expand Up @@ -164,32 +177,43 @@ def OperantLogicHelper(stop_duration: float = 0.2, is_operant: bool = False):
grace_distance_threshold=10,
)

def ExponentialDistributionHelper(rate: 1, minimum: 0, maximum: 1000):
def ExponentialDistributionHelper(rate=1, minimum=0, maximum=1000):
return distributions.ExponentialDistribution(
distribution_parameters=distributions.ExponentialDistributionParameters(rate=rate),
truncation_parameters=distributions.TruncationParameters(min=minimum, max=maximum, is_truncated=True),
scaling_parameters=distributions.ScalingParameters(scale=1.0, offset=0.0),
)

def Reward_VirtualSiteGeneratorHelper(contrast: float = 0.5):
def Reward_VirtualSiteGeneratorHelper(contrast: float = 0.5, friction: float = 0):
return vr_task_logic.VirtualSiteGenerator(
render_specification=vr_task_logic.RenderSpecification(contrast=contrast),
label=vr_task_logic.VirtualSiteLabels.REWARDSITE,
length_distribution=ExponentialDistributionHelper(1, 0, 10),
treadmill_specification=vr_task_logic.TreadmillSpecification(friction=vr_task_logic.scalar_value(friction)),
)

def InterSite_VirtualSiteGeneratorHelper(contrast: float = 0.5):
def InterSite_VirtualSiteGeneratorHelper(contrast: float = 0.5, friction: float = 0):
return vr_task_logic.VirtualSiteGenerator(
render_specification=vr_task_logic.RenderSpecification(contrast=contrast),
label=vr_task_logic.VirtualSiteLabels.INTERSITE,
length_distribution=ExponentialDistributionHelper(1, 0, 10),
treadmill_specification=vr_task_logic.TreadmillSpecification(friction=vr_task_logic.scalar_value(friction)),
)

def InterPatch_VirtualSiteGeneratorHelper(contrast: float = 1):
def InterPatch_VirtualSiteGeneratorHelper(contrast: float = 1, friction: float = 0):
return vr_task_logic.VirtualSiteGenerator(
render_specification=vr_task_logic.RenderSpecification(contrast=contrast),
label=vr_task_logic.VirtualSiteLabels.INTERPATCH,
length_distribution=ExponentialDistributionHelper(1, 0, 10),
treadmill_specification=vr_task_logic.TreadmillSpecification(friction=vr_task_logic.scalar_value(friction)),
)

def PostPatch_VirtualSiteGeneratorHelper(contrast: float = 1, friction: float = 0.5):
return vr_task_logic.VirtualSiteGenerator(
render_specification=vr_task_logic.RenderSpecification(contrast=contrast),
label=vr_task_logic.VirtualSiteLabels.POSTPATCH,
length_distribution=ExponentialDistributionHelper(1, 0, 10),
treadmill_specification=vr_task_logic.TreadmillSpecification(friction=vr_task_logic.scalar_value(friction)),
)

reward_function = vr_task_logic.PatchRewardFunction(
Expand All @@ -212,6 +236,7 @@ def InterPatch_VirtualSiteGeneratorHelper(contrast: float = 1):
inter_patch=InterPatch_VirtualSiteGeneratorHelper(),
inter_site=InterSite_VirtualSiteGeneratorHelper(),
reward_site=Reward_VirtualSiteGeneratorHelper(),
post_patch=PostPatch_VirtualSiteGeneratorHelper(),
),
)

Expand All @@ -235,9 +260,8 @@ def InterPatch_VirtualSiteGeneratorHelper(contrast: float = 1):
first_state=None, transition_matrix=vr_task_logic.Matrix2D(data=[[1, 0], [0, 1]]), patches=[patch1, patch2]
)

example_vr_task_logic = AindVrForagingTaskLogic(
return AindVrForagingTaskLogic(
task_parameters=AindVrForagingTaskParameters(
name="vr_foraging_task_stage_foo",
rng_seed=None,
updaters=updaters,
environment_statistics=environment_statistics,
Expand All @@ -246,18 +270,29 @@ def InterPatch_VirtualSiteGeneratorHelper(contrast: float = 1):
)
)


def mock_subject_database() -> db.SubjectDataBase:
"""Generates a mock database object"""
database = db.SubjectDataBase()
database.add_subject("test", db.SubjectEntry(task_logic_target="preward_intercept_stageA"))
database.add_subject("test2", db.SubjectEntry(task_logic_target="does_notexist"))
return database


def main(path_seed: str = "./local/{schema}.json"):

example_session = mock_session()
example_rig = mock_rig()
example_task_logic = mock_task_logic()
example_database = mock_subject_database()

os.makedirs(os.path.dirname(path_seed), exist_ok=True)

models = [example_task_logic, example_session, example_rig, example_database]

with open("local/example_vr_task_logic.json", "w") as f:
f.write(example_vr_task_logic.model_dump_json(indent=2))
with open("local/example_session.json", "w") as f:
f.write(example_session.model_dump_json(indent=2))
with open("local/example_rig.json", "w") as f:
f.write(example_rig.model_dump_json(indent=2))
with open("local/example_database.json", "w") as f:
f.write(database.model_dump_json(indent=2))
for model in models:
with open(path_seed.format(schema=model.__class__.__name__), "w", encoding="utf-8") as f:
f.write(model.model_dump_json(indent=2))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ readme = "README.md"
dynamic = ["version"]

dependencies = [
"aind_behavior_services@git+https://github.com/AllenNeuralDynamics/Aind.Behavior.Services@0.7.8",
"aind_behavior_services@git+https://github.com/AllenNeuralDynamics/Aind.Behavior.Services@0.7.9",
]

[project.optional-dependencies]
Expand Down
1 change: 0 additions & 1 deletion src/DataSchemas/aind_behavior_vr_foraging/rig.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class Treadmill(rig.Treadmill):
default=[[0, 0], [1, 65535]],
validate_default=True,
min_length=2,
min=0,
description="Break lookup calibration. Each Tuple is (0-1 (percent), 0-full-scale). \
Values are linearly interpolated",
)
Expand Down
17 changes: 10 additions & 7 deletions src/DataSchemas/aind_behavior_vr_foraging/task_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ class RenderSpecification(BaseModel):
contrast: Optional[float] = Field(default=None, ge=0, le=1, description="Contrast of the texture")


class TreadmillSpecification(BaseModel):
friction: Optional[distributions.Distribution] = Field(
default=None,
description="Friction of the treadmill (0-1). The drawn value must be between 0 and 1",
)


class VirtualSiteGenerator(BaseModel):
render_specification: RenderSpecification = Field(
default=RenderSpecification(), description="Contrast of the environment", validate_default=True
Expand All @@ -179,6 +186,9 @@ class VirtualSiteGenerator(BaseModel):
length_distribution: distributions.Distribution = Field(
default=scalar_value(20), description="Distribution of the length of the virtual site", validate_default=True
)
treadmill_specification: Optional[TreadmillSpecification] = Field(
default=None, description="Treadmill specification", validate_default=True
)


class VirtualSiteGeneration(BaseModel):
Expand All @@ -204,13 +214,6 @@ class VirtualSiteGeneration(BaseModel):
)


class TreadmillSpecification(BaseModel):
friction: Optional[distributions.Distribution] = Field(
default=None,
description="Friction of the treadmill (0-1). The drawn value must be between 0 and 1",
)


class VirtualSite(BaseModel):
id: int = Field(default=0, ge=0, description="Id of the virtual site")
label: VirtualSiteLabels = Field(default=VirtualSiteLabels.UNSPECIFIED, description="Label of the virtual site")
Expand Down
1 change: 0 additions & 1 deletion src/DataSchemas/aind_vr_foraging_rig.json
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,6 @@
"minItems": 2,
"type": "array"
},
"min": 0,
"minItems": 2,
"title": "Break Lookup Calibration",
"type": "array"
Expand Down
43 changes: 24 additions & 19 deletions src/DataSchemas/aind_vr_foraging_task_logic.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,6 @@
],
"title": "Rng Seed"
},
"stage_alias": {
"default": null,
"description": "Alias name used for the task stage",
"oneOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Stage Alias"
},
"updaters": {
"additionalProperties": {
"$ref": "#/definitions/NumericalUpdater"
Expand Down Expand Up @@ -1173,7 +1160,8 @@
},
"render_specification": {
"contrast": null
}
},
"treadmill_specification": null
},
"inter_patch": {
"label": "InterPatch",
Expand All @@ -1190,7 +1178,8 @@
},
"render_specification": {
"contrast": null
}
},
"treadmill_specification": null
},
"post_patch": null,
"reward_site": {
Expand All @@ -1208,7 +1197,8 @@
},
"render_specification": {
"contrast": null
}
},
"treadmill_specification": null
}
},
"description": "Virtual site generation specification"
Expand Down Expand Up @@ -1925,7 +1915,8 @@
"family": "Scalar",
"scaling_parameters": null,
"truncation_parameters": null
}
},
"treadmill_specification": null
},
"description": "Generator of the inter-site virtual sites"
},
Expand All @@ -1950,7 +1941,8 @@
"family": "Scalar",
"scaling_parameters": null,
"truncation_parameters": null
}
},
"treadmill_specification": null
},
"description": "Generator of the inter-patch virtual sites"
},
Expand Down Expand Up @@ -1987,7 +1979,8 @@
"family": "Scalar",
"scaling_parameters": null,
"truncation_parameters": null
}
},
"treadmill_specification": null
},
"description": "Generator of the reward-site virtual sites"
}
Expand Down Expand Up @@ -2035,6 +2028,18 @@
"scaling_parameters": null
},
"description": "Distribution of the length of the virtual site"
},
"treadmill_specification": {
"default": null,
"description": "Treadmill specification",
"oneOf": [
{
"$ref": "#/definitions/TreadmillSpecification"
},
{
"type": "null"
}
]
}
},
"title": "VirtualSiteGenerator",
Expand Down
5 changes: 3 additions & 2 deletions src/Extensions/AddGapSite.bonsai
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<WorkflowBuilder Version="2.8.1"
<WorkflowBuilder Version="2.8.2"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core"
xmlns:p1="clr-namespace:AindVrForagingDataSchema.TaskLogic;assembly=Extensions"
Expand All @@ -17,6 +17,7 @@
<Property Name="OdorSpecification" />
<Property Name="RewardSpecification" />
<Property Name="LengthDistribution" />
<Property Name="TreadmillSpecification" />
</Expression>
<Expression xsi:type="rx:SelectMany">
<Name>AddGapSite</Name>
Expand Down Expand Up @@ -107,10 +108,10 @@
<Property Name="Label" />
<Property Name="OdorSpecification" />
<Property Name="RewardSpecification" />
<Property Name="TreadmillSpecification" />
</Expression>
<Expression xsi:type="IncludeWorkflow" Path="Extensions\CreateSite.bonsai">
<Id>388</Id>
<Label>GapSite</Label>
<Length>20.7234077</Length>
<StartPosition>0</StartPosition>
</Expression>
Expand Down
Loading

0 comments on commit 64cce52

Please sign in to comment.