Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Sep 3, 2024
1 parent a4752b6 commit e4634e7
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_gnn_edge_attrs(
shape=multi_agent_obs.shape[len(batch_size) :]
),
"pos": UnboundedContinuousTensorSpec(
shape=multi_agent_obs.shape[len(batch_size) :]
shape=multi_agent_pos.shape[len(batch_size) :]
),
},
shape=(n_agents,),
Expand Down Expand Up @@ -361,6 +361,7 @@ def test_gnn_edge_attrs(
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
pos_features=pos_size if position_key is not None else 0,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down Expand Up @@ -392,6 +393,7 @@ def test_gnn_edge_attrs(
gnn_kwargs=None,
position_key=position_key,
exclude_pos_from_node_features=False,
pos_features=pos_size if position_key is not None else 0,
).get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down

0 comments on commit e4634e7

Please sign in to comment.