From 1449ca3379ec7a1263da72d47d5672d1e7715bdb Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 1 May 2024 12:20:11 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20inplace=20change=20?= =?UTF-8?q?of=20outputs=20passed=20to=20model=20forward=20pass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 1 - src/leibnetz/leibnet.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index e7ea71f..2f505fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,6 @@ include_package_data = True install_requires = torch numpy - funlib.learn.torch @ git+https://github.com/funkelab/funlib.learn.torch.git [options.packages.find] where=src diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index de5d0da..bb4051c 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -394,7 +394,7 @@ def forward(self, inputs: dict[str, dict[str, Sequence[int | float]]]): # outputs is a dictionary of tensors # initialize buffer - self.buffer = inputs + self.buffer = {key: inputs[key] for key in self.input_keys} # march along nodes based on graph succession for flushable_list, node in zip(self.flushable_arrays, self.ordered_nodes):