diff --git a/clrs/_src/processors.py b/clrs/_src/processors.py index a28414b0..9e7fccde 100644 --- a/clrs/_src/processors.py +++ b/clrs/_src/processors.py @@ -339,7 +339,7 @@ def d2_forward(self, adj_mat=adj_mat, hidden=hidden ) - emb_values.append(cell_embedding) + emb_values.append(cell_embedding[0]) ret_nodes.append( jnp.mean(jnp.stack(emb_values, axis=0), axis=0) )