Skip to content

Commit

Permalink
Merge pull request #16 from AllenNeuralDynamics/feat-force-integral-t…
Browse files Browse the repository at this point in the history
…rial-tyep

Implement force accumulation trial type
  • Loading branch information
bruno-f-cruz authored Aug 12, 2024
2 parents 14232e3 + c781533 commit 98104b0
Show file tree
Hide file tree
Showing 10 changed files with 790 additions and 170 deletions.
50 changes: 40 additions & 10 deletions src/DataSchemas/aind_behavior_force_foraging/task_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ class NumericalUpdater(BaseModel):
)


class UpdateTargetParameterBy(str, Enum):
"""Defines the independent variable used for the update"""

TIME = "Time"
REWARD = "Reward"
TRIAL = "Trial"


class UpdateTargetParameter(str, Enum):
"""Defines the target parameters"""

Expand All @@ -104,14 +112,6 @@ class UpdateTargetParameter(str, Enum):
DELAY = "Delay"


class UpdateTargetParameterBy(str, Enum):
"""Defines the independent variable used for the update"""

TIME = "Time"
REWARD = "Reward"
TRIAL = "Trial"


class ActionUpdater(BaseModel):
target_parameter: UpdateTargetParameter = Field(
default=UpdateTargetParameter.PROBABILITY, description="Target parameter"
Expand All @@ -122,10 +122,19 @@ class ActionUpdater(BaseModel):
updater: NumericalUpdater = Field(..., description="Updater")


class TrialType(str, Enum):
"""Defines the trial types"""

NONE = "None"
ACCUMULATION = "Accumulation"
ROI = "RegionOfInterest"


class HarvestAction(BaseModel):
"""Defines an abstract class for an harvest action"""

action: HarvestActionLabel = Field(default=HarvestActionLabel.NONE, description="Label of the action")
trial_type: TrialType = Field(default=TrialType.NONE, description="Type of the trial")
probability: float = Field(default=1, description="Probability of reward")
amount: float = Field(default=1, description="Amount of reward to be delivered")
delay: float = Field(default=0, description="Delay between successful harvest and reward delivery")
Expand All @@ -134,7 +143,7 @@ class HarvestAction(BaseModel):
default=MAX_LOAD_CELL_FORCE,
le=MAX_LOAD_CELL_FORCE,
ge=-MAX_LOAD_CELL_FORCE,
description="Upper bound of the force target region.",
description="Upper bound of the force target region or the target cached force required.",
)
lower_force_threshold: float = Field(
default=5000,
Expand All @@ -153,13 +162,29 @@ class HarvestAction(BaseModel):
)

@model_validator(mode="after")
def check_passwords_match(self) -> Self:
def _validate_thresholds(self) -> Self:
if self.upper_force_threshold < self.lower_force_threshold:
raise ValueError(
f"Upper force threshold ({self.upper_force_threshold}) must be greater than lower force threshold({self.lower_force_threshold})"
)
return self

@model_validator(mode="after")
def _validate_trial_type(self) -> Self:
if self.trial_type == TrialType.ROI:
if not all(
[
self._between_thresholds(self.upper_force_threshold),
self._between_thresholds(self.lower_force_threshold),
]
):
raise ValueError("Force thresholds must be between -32768 and 32768 for ROI trials")
return self

@staticmethod
def _between_thresholds(value: float) -> bool:
return value <= MAX_LOAD_CELL_FORCE and value >= -MAX_LOAD_CELL_FORCE


class QuiescencePeriod(BaseModel):
"""Defines a quiescence settings"""
Expand Down Expand Up @@ -284,6 +309,11 @@ class ForceLookUpTable(BaseModel):
path: str = Field(
..., description="Reference to the look up table image. Should be a 1 channel image. Value = LUT[Left, Right]"
)

offset: float = Field(default=0, description="Offset to add to the look up table value")

scale: float = Field(default=1, description="Scale to multiply the look up table value")

left_min: float = Field(
..., description="The lower value of Left force used to linearly scale the input coordinate to."
)
Expand Down
35 changes: 34 additions & 1 deletion src/DataSchemas/aind_force_foraging_task_logic.json
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,18 @@
"title": "Path",
"type": "string"
},
"offset": {
"default": 0,
"description": "Offset to add to the look up table value",
"title": "Offset",
"type": "number"
},
"scale": {
"default": 1,
"description": "Scale to multiply the look up table value",
"title": "Scale",
"type": "number"
},
"left_min": {
"description": "The lower value of Left force used to linearly scale the input coordinate to.",
"title": "Left Min",
Expand Down Expand Up @@ -769,6 +781,15 @@
"default": "None",
"description": "Label of the action"
},
"trial_type": {
"allOf": [
{
"$ref": "#/definitions/TrialType"
}
],
"default": "None",
"description": "Type of the trial"
},
"probability": {
"default": 1,
"description": "Probability of reward",
Expand All @@ -795,7 +816,7 @@
},
"upper_force_threshold": {
"default": 32768,
"description": "Upper bound of the force target region.",
"description": "Upper bound of the force target region or the target cached force required.",
"maximum": 32768.0,
"minimum": -32768.0,
"title": "Upper Force Threshold",
Expand Down Expand Up @@ -1392,6 +1413,7 @@
"left_harvest": {
"default": {
"action": "Left",
"trial_type": "None",
"probability": 1.0,
"amount": 1.0,
"delay": 0.0,
Expand All @@ -1415,6 +1437,7 @@
"right_harvest": {
"default": {
"action": "Right",
"trial_type": "None",
"probability": 1.0,
"amount": 1.0,
"delay": 0.0,
Expand All @@ -1439,6 +1462,16 @@
"title": "Trial",
"type": "object"
},
"TrialType": {
"description": "Defines the trial types",
"enum": [
"None",
"Accumulation",
"RegionOfInterest"
],
"title": "TrialType",
"type": "string"
},
"TruncationParameters": {
"properties": {
"is_truncated": {
Expand Down
81 changes: 81 additions & 0 deletions src/Extensions/AccumulationTrialTypeResponse.bonsai
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<?xml version="1.0" encoding="utf-8"?>
<WorkflowBuilder Version="2.8.5"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core"
xmlns="https://bonsai-rx.org/2018/workflow">
<Workflow>
<Nodes>
<Expression xsi:type="WorkflowInput">
<Name>Source1</Name>
</Expression>
<Expression xsi:type="rx:Defer">
<Name>AccumulationTrialTypeResponse</Name>
<Workflow>
<Nodes>
<Expression xsi:type="WorkflowInput">
<Name>Source1</Name>
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="rx:Take">
<rx:Count>1</rx:Count>
</Combinator>
</Expression>
<Expression xsi:type="rx:AsyncSubject">
<Name>trialType</Name>
</Expression>
<Expression xsi:type="SubscribeSubject">
<Name>ThisTrial</Name>
</Expression>
<Expression xsi:type="MemberSelector">
<Selector>LeftHarvest</Selector>
</Expression>
<Expression xsi:type="IncludeWorkflow" Path="Extensions\ObserveAccumulationAction.bonsai" />
<Expression xsi:type="SubscribeSubject">
<Name>LeftHasReward</Name>
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="rx:WithLatestFrom" />
</Expression>
<Expression xsi:type="SubscribeSubject">
<Name>ThisTrial</Name>
</Expression>
<Expression xsi:type="MemberSelector">
<Selector>RightHarvest</Selector>
</Expression>
<Expression xsi:type="IncludeWorkflow" Path="Extensions\ObserveAccumulationAction.bonsai" />
<Expression xsi:type="SubscribeSubject">
<Name>RightHasReward</Name>
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="rx:WithLatestFrom" />
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="rx:Merge" />
</Expression>
<Expression xsi:type="WorkflowOutput" />
</Nodes>
<Edges>
<Edge From="0" To="1" Label="Source1" />
<Edge From="1" To="2" Label="Source1" />
<Edge From="3" To="4" Label="Source1" />
<Edge From="4" To="5" Label="Source1" />
<Edge From="5" To="7" Label="Source1" />
<Edge From="6" To="7" Label="Source2" />
<Edge From="7" To="13" Label="Source1" />
<Edge From="8" To="9" Label="Source1" />
<Edge From="9" To="10" Label="Source1" />
<Edge From="10" To="12" Label="Source1" />
<Edge From="11" To="12" Label="Source2" />
<Edge From="12" To="13" Label="Source2" />
<Edge From="13" To="14" Label="Source1" />
</Edges>
</Workflow>
</Expression>
<Expression xsi:type="WorkflowOutput" />
</Nodes>
<Edges>
<Edge From="0" To="1" Label="Source1" />
<Edge From="1" To="2" Label="Source1" />
</Edges>
</Workflow>
</WorkflowBuilder>
Loading

0 comments on commit 98104b0

Please sign in to comment.