From c6d1eb30d270f61b796ac2fab7cd515ee29eafa6 Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Tue, 16 Apr 2024 17:35:05 +0200 Subject: [PATCH] Fixup navierstokes. --- src/continuiti/benchmarks/navierstokes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/continuiti/benchmarks/navierstokes.py b/src/continuiti/benchmarks/navierstokes.py index fb136e01..66c4e304 100644 --- a/src/continuiti/benchmarks/navierstokes.py +++ b/src/continuiti/benchmarks/navierstokes.py @@ -69,12 +69,12 @@ def __init__(self, dir: Optional[str] = None): ls = torch.linspace(-1, 1, 64) tx = torch.linspace(-0.9, 0.0, 10) grid_x = torch.meshgrid(ls, ls, tx, indexing="ij") - x = torch.stack(grid_x, axis=3).reshape(1, -1, 3).repeat(1200, 1, 1) + x = torch.stack(grid_x, axis=0).unsqueeze(0).expand(1200, -1, -1, -1, -1) x = x.reshape(1200, 3, 64, 64, 10) ty = torch.linspace(0.1, 1.0, 10) grid_y = torch.meshgrid(ls, ls, ty, indexing="ij") - y = torch.stack(grid_y, axis=3).reshape(1, -1, 3).repeat(1200, 1, 1) + y = torch.stack(grid_y, axis=0).unsqueeze(0).expand(1200, -1, -1, -1, -1) y = y.reshape(1200, 3, 64, 64, 10) # Load vorticity