From f5d10dcb01e4342572935bb788a616a4ae32948a Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sun, 21 Jan 2024 05:47:36 +0000 Subject: [PATCH] Fix flake8 --- .../generate_data_stokes_tutorial.py | 3 +-- .../dataset_processing/pde_dataset.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/physics_driven_ml/dataset_processing/generate_data_stokes_tutorial.py b/physics_driven_ml/dataset_processing/generate_data_stokes_tutorial.py index 735b6c1..02b9270 100644 --- a/physics_driven_ml/dataset_processing/generate_data_stokes_tutorial.py +++ b/physics_driven_ml/dataset_processing/generate_data_stokes_tutorial.py @@ -1,6 +1,5 @@ import os import argparse -import numpy as np from typing import Union, Callable from tqdm.auto import tqdm, trange from numpy.random import default_rng @@ -67,7 +66,7 @@ def solve_stokes_cylinder(fs): # "mat_type": "aij", # "pc_type": "lu", # "ksp_atol": 1.0e-9, - # "pc_factor_mat_solver_type": "mumps"}) + # "pc_factor_mat_solver_type": "mumps"}) return us, edge_index diff --git a/physics_driven_ml/dataset_processing/pde_dataset.py b/physics_driven_ml/dataset_processing/pde_dataset.py index da33d08..ffdd02c 100644 --- a/physics_driven_ml/dataset_processing/pde_dataset.py +++ b/physics_driven_ml/dataset_processing/pde_dataset.py @@ -110,3 +110,23 @@ def __getitem__(self, idx: int) -> GraphBatchElement: return GraphBatchElement(u=target_u, f=target_f, u_fd=target_u_fd, f_fd=target_f_fd) + def collate(self, batch_elements: List[GraphBatchElement]) -> GraphBatchedElement: + # Workaround to enable custom data types (e.g. firedrake.Function) in PyTorch dataloaders + # See: https://pytorch.org/docs/stable/data.html#working-with-collate-fn + batch_size = len(batch_elements) + n = max(e.u.size(-1) for e in batch_elements) + m = max(e.f.size(-1) for e in batch_elements) + + u = torch.zeros(batch_size, n, dtype=batch_elements[0].u.dtype) + f = torch.zeros(batch_size, m, dtype=batch_elements[0].f.dtype) + f_fd = [] + u_fd = [] + for i, e in enumerate(batch_elements): + u[i, :] = e.u + f[i, :] = e.f + u_fd.append(e.u_fd) + f_fd.append(e.f_fd) + + return GraphBatchedElement(u=u, f=f, + u_fd=u_fd, f_fd=f_fd, + batch_elements=batch_elements)