Fix flake8
nbouziani committed Jan 21, 2024
1 parent 8d222ae commit f5d10dc
Showing 2 changed files with 21 additions and 2 deletions.
@@ -1,6 +1,5 @@
import os
import argparse
import numpy as np
from typing import Union, Callable
from tqdm.auto import tqdm, trange
from numpy.random import default_rng
@@ -67,7 +66,7 @@ def solve_stokes_cylinder(fs):
# "mat_type": "aij",
# "pc_type": "lu",
# "ksp_atol": 1.0e-9,
# "pc_factor_mat_solver_type": "mumps"})
# "pc_factor_mat_solver_type": "mumps"})
return us, edge_index
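
The options commented out above are standard PETSc settings for a direct LU solve via MUMPS. As a hedged illustration only, and not code from this repository, the minimal Firedrake Poisson solve below shows how such a solver_parameters dictionary would typically be passed; every name in it (mesh, V, uh, ...) is an assumption made for the example, and the "mumps" option requires a PETSc build with MUMPS support.

import firedrake as fd

# Minimal Poisson problem, used only to illustrate passing PETSc options.
mesh = fd.UnitSquareMesh(8, 8)
V = fd.FunctionSpace(mesh, "CG", 1)
u, v = fd.TrialFunction(V), fd.TestFunction(V)
a = fd.inner(fd.grad(u), fd.grad(v)) * fd.dx
L = fd.Constant(1.0) * v * fd.dx
bc = fd.DirichletBC(V, 0.0, "on_boundary")

uh = fd.Function(V)
# Direct solve with LU/MUMPS, mirroring the commented-out options above.
fd.solve(a == L, uh, bcs=bc,
         solver_parameters={"mat_type": "aij",
                            "pc_type": "lu",
                            "ksp_atol": 1.0e-9,
                            "pc_factor_mat_solver_type": "mumps"})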


20 changes: 20 additions & 0 deletions physics_driven_ml/dataset_processing/pde_dataset.py
@@ -110,3 +110,23 @@ def __getitem__(self, idx: int) -> GraphBatchElement:
return GraphBatchElement(u=target_u, f=target_f,
u_fd=target_u_fd, f_fd=target_f_fd)

def collate(self, batch_elements: List[GraphBatchElement]) -> GraphBatchedElement:
    # Workaround to enable custom data types (e.g. firedrake.Function) in PyTorch dataloaders
    # See: https://pytorch.org/docs/stable/data.html#working-with-collate-fn
    batch_size = len(batch_elements)
    n = max(e.u.size(-1) for e in batch_elements)
    m = max(e.f.size(-1) for e in batch_elements)

    u = torch.zeros(batch_size, n, dtype=batch_elements[0].u.dtype)
    f = torch.zeros(batch_size, m, dtype=batch_elements[0].f.dtype)
    f_fd = []
    u_fd = []
    for i, e in enumerate(batch_elements):
        # Zero-pad each sample up to the largest size in the batch
        u[i, :e.u.size(-1)] = e.u
        f[i, :e.f.size(-1)] = e.f
        u_fd.append(e.u_fd)
        f_fd.append(e.f_fd)

    return GraphBatchedElement(u=u, f=f,
                               u_fd=u_fd, f_fd=f_fd,
                               batch_elements=batch_elements)
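
As the comment in collate notes, a custom collate function is needed because PyTorch's default batching cannot handle Firedrake objects. The following is a hedged usage sketch, not code from this repository: the dataset construction is elided because its constructor is not shown in this diff, and the batch_size value is arbitrary.

from torch.utils.data import DataLoader

dataset = ...  # an instance of the dataset class defined in pde_dataset.py
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=dataset.collate)

for batch in loader:
    # batch.u / batch.f are zero-padded tensors of shape (batch_size, n) and (batch_size, m);
    # batch.u_fd / batch.f_fd are plain Python lists holding the Firedrake objects.
    print(batch.u.shape, batch.f.shape)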
