diff --git a/blackbirds/models/rama_cont.py b/blackbirds/models/rama_cont.py index f2406ed..bc5e043 100644 --- a/blackbirds/models/rama_cont.py +++ b/blackbirds/models/rama_cont.py @@ -35,7 +35,7 @@ def initialize(self, params): def step(self, params, x): # draw epsilon_t from normal distribution sigma = params[2] - epsilon_t = torch.distributions.Normal(0, sigma).rsample((self.n_agents,)) + epsilon_t = torch.distributions.Normal(0, sigma).rsample() # compute order nu_t = x[-1, 0, :] order = self.compute_order(epsilon_t, nu_t) diff --git a/requirements.txt b/requirements.txt index 7747c32..e301890 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ grad-june>=0.1.8 networkx>=3.0 normflows>=1.6.2 +numpy==1.26.4 pyyaml>=6.0 tensorboard>=2.12.1 torch>=2.0