forked from NVIDIA/modulus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pi_fine_tuning.py
571 lines (485 loc) · 18.1 KB
/
pi_fine_tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import time, os
import wandb
import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig
try:
import apex
except:
pass
try:
import pyvista as pv
except:
raise ImportError(
"Stokes Dataset requires the pyvista library. Install with "
+ "pip install pyvista"
)
from modulus.models.mlp.fully_connected import FullyConnected
from modulus.launch.logging import (
PythonLogger,
initialize_wandb,
RankZeroLoggingWrapper,
)
from modulus.launch.utils import load_checkpoint, save_checkpoint
from utils import relative_lp_error, get_dataset
from collections import OrderedDict
from sympy import Symbol, Function, Number
from modulus.sym.eq.pde import PDE
from modulus.sym.node import Node
from modulus.sym.key import Key
from modulus.sym.graph import Graph
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.models.fourier_net import FourierNetArch
from modulus.sym.models.arch import Arch
from modulus.sym.key import Key
from modulus.sym.manager import GraphManager
from typing import Optional, Dict
class Stokes(PDE):
    """Incompressible Stokes flow.

    Builds the symbolic residuals of the steady, incompressible Stokes
    equations (no convective term)::

        continuity:  du/dx + dv/dy + dw/dz
        momentum_i:  dp/dx_i - nu * (d2u_i/dx2 + d2u_i/dy2 + d2u_i/dz2)

    Parameters
    ----------
    nu : str, float or int
        Kinematic viscosity. A ``str`` becomes an unknown SymPy function of the
        input coordinates (spatially varying viscosity); a ``float`` or ``int``
        is wrapped in ``sympy.Number``. Any other value is used unchanged.
    dim : int, optional
        Spatial dimension. With ``dim == 2`` the ``z`` coordinate is dropped,
        the ``w`` velocity is the constant 0 and no ``momentum_z`` equation is
        produced. Defaults to 3.
    """

    def __init__(self, nu, dim=3):
        self.dim = dim

        x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
        coordinates = {"x": x, "y": y, "z": z}
        if dim == 2:
            del coordinates["z"]

        # The unknown fields are applied to the coordinate *names*; SymPy
        # sympifies each name to the Symbol of the same name defined above.
        u = Function("u")(*coordinates)
        v = Function("v")(*coordinates)
        w = Function("w")(*coordinates) if dim == 3 else Number(0)
        p = Function("p")(*coordinates)

        if isinstance(nu, str):
            nu = Function(nu)(*coordinates)
        elif isinstance(nu, (float, int)):
            nu = Number(nu)

        def laplacian(field):
            # In 2D the z-terms vanish because no field depends on z.
            return (
                field.diff(x).diff(x)
                + field.diff(y).diff(y)
                + field.diff(z).diff(z)
            )

        self.equations = {
            "continuity": u.diff(x) + v.diff(y) + w.diff(z),
            "momentum_x": p.diff(x) - nu * laplacian(u),
            "momentum_y": p.diff(y) - nu * laplacian(v),
        }
        if dim != 2:
            self.equations["momentum_z"] = p.diff(z) - nu * laplacian(w)
class DNN(torch.nn.Module):
    """
    Fully connected network with a random Fourier-feature input embedding.

    The input ``x`` (last dimension ``layers[0]``) is projected by a fixed
    random matrix ``B`` of shape ``(layers[0], fourier_features)`` and embedded
    as ``[sin(x @ B), cos(x @ B)]``, which has ``2 * fourier_features``
    features. The embedding takes the place of a first linear layer. It feeds a
    stack of ``Linear`` + GELU layers built from ``layers[1:-1]``, followed by
    a final ``Linear`` producing ``layers[-1]`` outputs (no output activation).

    Parameters
    ----------
    layers : sequence of int
        ``[in_features, width_1, ..., width_k, out_features]``. The first linear
        layer consumes the embedding, so its input width (``layers[1]``, or
        ``layers[0]`` when there are only two entries) must equal
        ``2 * fourier_features``.
    fourier_features : int, optional
        Number of random frequencies, i.e. columns of ``B``. Defaults to 64.
    fourier_scale : float, optional
        Multiplier applied to the standard-normal entries of ``B``. Defaults to
        10.0, the value that used to be hard-coded.

    Raises
    ------
    ValueError
        If the first linear layer's input width does not match the embedding
        width. Such a network could never run; previously this surfaced only as
        a shape error inside ``forward``.
    """

    def __init__(self, layers, fourier_features=64, fourier_scale=10.0):
        super(DNN, self).__init__()

        # Validate before drawing any random numbers so the RNG stream (and
        # therefore the initialisation of valid models) is unchanged.
        embed_width = layers[1] if len(layers) > 2 else layers[0]
        if embed_width != 2 * fourier_features:
            raise ValueError(
                f"The first linear layer has input width {embed_width}, but the "
                f"[sin, cos] Fourier embedding produces 2 * fourier_features = "
                f"{2 * fourier_features} features."
            )

        # parameters
        self.depth = len(layers) - 1

        # Fixed (non-trainable) random projection, registered as a buffer so it
        # is saved in the state dict and follows the module across devices.
        self.fourier_features = fourier_features
        self.register_buffer(
            "B", fourier_scale * torch.randn((layers[0], fourier_features))
        )

        # Hidden Linear + activation pairs; no Linear(layers[0], layers[1])
        # because the Fourier embedding already supplies layers[1] features.
        self.activation = torch.nn.GELU
        layer_list = []
        for i in range(1, self.depth - 1):
            layer_list.append(
                (f"layer_{i}", torch.nn.Linear(layers[i], layers[i + 1]))
            )
            layer_list.append((f"activation_{i}", self.activation()))
        layer_list.append(
            (f"layer_{self.depth - 1}", torch.nn.Linear(layers[-2], layers[-1]))
        )

        self.layers = torch.nn.Sequential(OrderedDict(layer_list))

    def forward(self, x):
        """Embed ``x`` (features on the last dim) with Fourier features, then run the MLP."""
        x_proj = torch.matmul(x, self.B)
        x_proj = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return self.layers(x_proj)
class MdlsSymDNN(Arch):
    """
    Wrapper model to convert PyTorch model to Modulus-Sym model.

    Modulus Sym relies on the inputs/outputs of the model being dictionary of tensors.
    This wrapper converts the input dictionary of tensors to a single tensor by
    concatenating them along the last dimension before passing them as an input to
    the pytorch model. During the output, the process is reversed:
    the output tensor from the pytorch model is split along the last dimension and
    converted to a dictionary with the output keys to produce the final output.

    The model arguments thus become a list of `Key` objects that informs the model
    about the input and output dimensionality of the pytorch model.

    For more details on Modulus Sym models, refer:
    https://docs.nvidia.com/deeplearning/modulus/modulus-core/tutorials/simple_training_example.html#using-custom-models-in-modulus
    For more details on Key class, refer:
    https://docs.nvidia.com/deeplearning/modulus/modulus-sym/api/modulus.sym.html#module-modulus.sym.key

    Parameters
    ----------
    input_keys : list of Key, optional
        Defaults to ``[Key("x"), Key("y")]``.
    output_keys : list of Key, optional
        Defaults to ``[Key("u"), Key("v"), Key("p")]``.
    layers : list of int, optional
        Layer widths forwarded to ``DNN``. Defaults to ``[2, 128, 128, 128, 128, 3]``.
    fourier_features : int, optional
        Number of Fourier features forwarded to ``DNN``. Defaults to 64.
    """

    def __init__(
        self,
        input_keys=None,
        output_keys=None,
        layers=None,
        fourier_features=64,
    ):
        # Build the defaults per call: list literals in the signature would be a
        # single mutable object shared by every instance.
        if input_keys is None:
            input_keys = [Key("x"), Key("y")]
        if output_keys is None:
            output_keys = [Key("u"), Key("v"), Key("p")]
        if layers is None:
            layers = [2, 128, 128, 128, 128, 3]

        super(MdlsSymDNN, self).__init__(
            input_keys=input_keys,
            output_keys=output_keys,
        )
        self.mdls_model = DNN(layers, fourier_features)

    def forward(self, dict_tensor: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Concatenate the input dict, run the network, split the output back into a dict."""
        # Use concat_input method of the Arch class to convert dict of tensors to
        # a single multi-dimensional tensor. Ref: https://github.com/NVIDIA/modulus-sym/blob/main/modulus/sym/models/arch.py#L251
        x = self.concat_input(
            dict_tensor,
            self.input_key_dict,
            detach_dict=self.detach_key_dict,
            dim=-1,
        )
        out = self.mdls_model(x)
        # Use split_output method of the Arch class to convert a single multi-dimensional
        # tensor to a dict of tensors. Ref: https://github.com/NVIDIA/modulus-sym/blob/main/modulus/sym/models/arch.py#L381
        # Split on the same (last) dimension used for concatenation; a fixed
        # dim=1 would be wrong for inputs with extra leading dimensions.
        return self.split_output(out, self.output_key_dict, dim=-1)
class PhysicsInformedFineTuner:
    """
    Physics-informed (PINN) fine-tuning of GNN predictions for 2D Stokes flow.

    An ``MdlsSymDNN`` network mapping ``(x, y) -> (u, v, p)`` is trained so that
    it matches the GNN predictions at the interior points, reproduces a
    parabolic inflow profile, has zero velocity at the no-slip points, and
    minimises the Stokes continuity / momentum residuals at the interior points.
    The reference fields are only used by :meth:`validation`.

    Parameters
    ----------
    device : torch.device or str
        Device holding the model and every tensor created here.
    gnn_u, gnn_v, gnn_p : array-like
        GNN predictions at ``coords``; targets of the data loss. Stored as
        ``(N, 1)`` float32 columns.
    coords : array-like
        Interior points; column 0 is ``x``, column 1 is ``y``.
    coords_inflow, coords_noslip : array-like
        Inflow and no-slip boundary points with the same column layout.
    nu : str, float or int
        Kinematic viscosity forwarded to ``Stokes``.
    ref_u, ref_v, ref_p : array-like
        Reference fields at ``coords``. Stored as ``(N, 1)`` float32 columns.
    """

    def __init__(
        self,
        device,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_noslip,
        nu,
        ref_u,
        ref_v,
        ref_p,
    ):
        super(PhysicsInformedFineTuner, self).__init__()
        self.device = device
        self.nu = nu

        # Field values become (N, 1) columns so they line up with the (N, 1)
        # network outputs instead of silently broadcasting to (N, N).
        self.ref_u = self._as_tensor(ref_u, device).reshape(-1, 1)
        self.ref_v = self._as_tensor(ref_v, device).reshape(-1, 1)
        self.ref_p = self._as_tensor(ref_p, device).reshape(-1, 1)
        self.gnn_u = self._as_tensor(gnn_u, device).reshape(-1, 1)
        self.gnn_v = self._as_tensor(gnn_v, device).reshape(-1, 1)
        self.gnn_p = self._as_tensor(gnn_p, device).reshape(-1, 1)

        # Coordinates track gradients: the PDE residuals differentiate the
        # network outputs with respect to x and y.
        self.coords = self._as_tensor(coords, device, requires_grad=True)
        self.coords_inflow = self._as_tensor(coords_inflow, device, requires_grad=True)
        self.coords_noslip = self._as_tensor(coords_noslip, device, requires_grad=True)

        self.model = MdlsSymDNN(
            input_keys=[Key("x"), Key("y")],
            output_keys=[Key("u"), Key("v"), Key("p")],
            layers=[2, 128, 128, 128, 128, 3],
            fourier_features=64,
        ).to(self.device)

        self.node_pde = Stokes(nu=self.nu, dim=2)
        self.nodes = self.node_pde.make_nodes() + [
            self.model.make_node(name="flow_network", jit=False)
        ]

        # note: this example uses the Graph class from Modulus Sym to construct the
        # computational graph. This allows you to leverage Modulus Sym's optimized
        # derivative backend to compute the derivatives, along with other benefits like
        # symbolic definition of PDEs and leveraging the PDEs from Modulus Sym's PDE
        # module.
        # For more details, refer: https://docs.nvidia.com/deeplearning/modulus/modulus-sym/api/modulus.sym.html#module-modulus.sym.graph

        # Network outputs only (no PDE residuals, hence no derivatives needed).
        self.graph = Graph(
            self.nodes,
            [Key("x"), Key("y")],
            [Key("u"), Key("v"), Key("p")],
            func_arch=False,
        ).to(self.device)

        # Network outputs plus the Stokes residuals.
        self.graph_full = Graph(
            self.nodes,
            [Key("x"), Key("y")],
            [
                Key("u"),
                Key("v"),
                Key("p"),
                Key("continuity"),
                Key("momentum_x"),
                Key("momentum_y"),
            ],
            func_arch=False,
        ).to(self.device)

        # Fused Adam needs the parameters on CUDA, so decide from where the
        # parameters actually live rather than from mere CUDA availability.
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=0.001,
            fused=all(param.is_cuda for param in self.model.parameters()),
        )

    @staticmethod
    def _as_tensor(values, device, requires_grad=False):
        """Copy ``values`` into a new float32 tensor on ``device``.

        Creating the tensor directly with the final dtype and device makes it a
        leaf, so no intermediate (e.g. float64 CPU) copy is kept alive by
        autograd when ``requires_grad`` is set.
        """
        return torch.tensor(
            values, dtype=torch.float32, device=device, requires_grad=requires_grad
        )

    def parabolic_inflow(self, y, U_max=0.3, height=0.4):
        """Return the parabolic inflow profile ``(u, v)`` evaluated at ``y``.

        ``u = 4 * U_max * y * (height - y) / height**2`` is zero at ``y = 0`` and
        ``y = height`` and equals ``U_max`` at ``y = height / 2``; ``v`` is zero.
        ``height`` defaults to 0.4, the value that used to be hard-coded.
        """
        u = 4 * U_max * y * (height - y) / (height**2)
        v = torch.zeros_like(y)
        return u, v

    def loss(self):
        """Evaluate the unweighted loss terms.

        Returns
        -------
        tuple of 0-d tensors
            ``(loss_u, loss_v, loss_p, loss_u_in, loss_v_in, loss_u_noslip,
            loss_v_noslip, loss_mom_u, loss_mom_v, loss_cont)``: mean-squared
            misfit against the GNN fields, inflow and no-slip boundary misfits,
            and mean-squared PDE residuals at the interior points.
        """
        # Boundary terms only need u and v, so use the graph without PDE
        # outputs and avoid computing second derivatives at those points.
        x_in, y_in = self.coords_inflow[:, 0:1], self.coords_inflow[:, 1:2]
        results_inflow = self.graph.forward({"x": x_in, "y": y_in})
        pred_u_in, pred_v_in = results_inflow["u"], results_inflow["v"]

        x_no_slip, y_no_slip = self.coords_noslip[:, 0:1], self.coords_noslip[:, 1:2]
        results_noslip = self.graph.forward({"x": x_no_slip, "y": y_no_slip})
        pred_u_noslip, pred_v_noslip = results_noslip["u"], results_noslip["v"]

        # Interior points: outputs and PDE residuals.
        x_int, y_int = self.coords[:, 0:1], self.coords[:, 1:2]
        results_int = self.graph_full.forward({"x": x_int, "y": y_int})
        pred_mom_u, pred_mom_v, pred_cont = (
            results_int["momentum_x"],
            results_int["momentum_y"],
            results_int["continuity"],
        )
        pred_u, pred_v, pred_p = results_int["u"], results_int["v"], results_int["p"]

        u_in, v_in = self.parabolic_inflow(y_in)

        # data loss
        loss_u = torch.mean((self.gnn_u - pred_u) ** 2)
        loss_v = torch.mean((self.gnn_v - pred_v) ** 2)
        loss_p = torch.mean((self.gnn_p - pred_p) ** 2)

        # inflow boundary condition loss
        loss_u_in = torch.mean((u_in - pred_u_in) ** 2)
        loss_v_in = torch.mean((v_in - pred_v_in) ** 2)

        # noslip boundary condition loss
        loss_u_noslip = torch.mean(pred_u_noslip**2)
        loss_v_noslip = torch.mean(pred_v_noslip**2)

        # pde loss
        loss_mom_u = torch.mean(pred_mom_u**2)
        loss_mom_v = torch.mean(pred_mom_v**2)
        loss_cont = torch.mean(pred_cont**2)

        return (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
        )

    def train(self):
        """Run one Adam step on the weighted sum of all loss terms.

        Returns the unweighted loss terms in the same order as :meth:`loss`.
        """
        losses = self.loss()
        (
            loss_u,
            loss_v,
            loss_p,
            loss_u_in,
            loss_v_in,
            loss_u_noslip,
            loss_v_noslip,
            loss_mom_u,
            loss_mom_v,
            loss_cont,
        ) = losses

        # Add custom weights to the different losses. The weights are chosen after
        # investigating the relative magnitudes of individual losses and their
        # convergence behavior.
        loss = (
            1 * loss_u
            + 1 * loss_v
            + 1 * loss_p
            + 10 * loss_u_in
            + 10 * loss_v_in
            + 10 * loss_u_noslip
            + 10 * loss_v_noslip
            + 1 * loss_mom_u
            + 1 * loss_mom_v
            + 10 * loss_cont
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return losses

    def validation(self):
        """Relative L2 errors of the current u, v, p against the reference fields.

        Each error is ``||ref - pred||_2 / ||ref||_2`` over the interior points,
        returned as a 0-d tensor and logged to Weights & Biases. The model's
        previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                x_int, y_int = self.coords[:, 0:1], self.coords[:, 1:2]
                results_int = self.graph.forward({"x": x_int, "y": y_int})
                error_u = torch.linalg.norm(
                    self.ref_u - results_int["u"]
                ) / torch.linalg.norm(self.ref_u)
                error_v = torch.linalg.norm(
                    self.ref_v - results_int["v"]
                ) / torch.linalg.norm(self.ref_v)
                error_p = torch.linalg.norm(
                    self.ref_p - results_int["p"]
                ) / torch.linalg.norm(self.ref_p)
        finally:
            # Without this the model would stay in eval mode for the rest of
            # training after the first validation call.
            self.model.train(was_training)

        # NOTE(review): these are fractions, not percentages, despite "(%)" in
        # the keys; the keys are kept unchanged for existing W&B dashboards.
        wandb.log(
            {
                "test_u_error (%)": error_u.detach().cpu().numpy(),
                "test_v_error (%)": error_v.detach().cpu().numpy(),
                "test_p_error (%)": error_p.detach().cpu().numpy(),
            }
        )

        return error_u, error_v, error_p
@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Fine-tune the GNN prediction of one Stokes sample with a PINN and save it.

    Reads the mesh at ``cfg.results_dir / cfg.graph_path``, runs ``cfg.pi_iters``
    optimisation steps (logging losses and validation errors every 100 steps),
    then writes the fine-tuned fields back into the same file as
    ``filtered_u``, ``filtered_v`` and ``filtered_p``.
    """
    # CUDA support
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # initialize loggers
    initialize_wandb(
        project="Modulus-Launch",
        entity="Modulus",
        name="Stokes-Physics-Informed-Fine-Tuning",
        group="Stokes-DDP-Group",
        mode=cfg.wandb_mode,
    )
    logger = PythonLogger("main")  # General python logger
    logger.file_logging()

    # Get dataset
    path = os.path.join(to_absolute_path(cfg.results_dir), cfg.graph_path)

    # get_dataset() function here provides the true values (ref_*) and the gnn
    # predictions (gnn_*) along with other data required for the PINN training.
    (
        ref_u,
        ref_v,
        ref_p,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_outflow,  # not used by the fine-tuning losses
        coords_wall,
        coords_polygon,
        nu,
    ) = get_dataset(path)
    coords_noslip = np.concatenate([coords_wall, coords_polygon], axis=0)

    # Initialize model
    pi_fine_tuner = PhysicsInformedFineTuner(
        device,
        gnn_u,
        gnn_v,
        gnn_p,
        coords,
        coords_inflow,
        coords_noslip,
        nu,
        ref_u,
        ref_v,
        ref_p,
    )

    # Labels in the order PhysicsInformedFineTuner.train() returns the losses.
    loss_labels = (
        "u",
        "v",
        "p",
        "u_in",
        "v_in",
        "u noslip",
        "v noslip",
        "momentum u",
        "momentum v",
        "continuity",
    )

    logger.info("Inference (with physics-informed training for fine-tuning) started...")
    for iters in range(cfg.pi_iters):
        # Start timing the iteration
        start_iter_time = time.time()

        losses = pi_fine_tuner.train()

        if iters % 100 == 0:
            error_u, error_v, error_p = pi_fine_tuner.validation()

            # Print losses
            logger.info(f"Iteration: {iters}")
            for label, value in zip(loss_labels, losses):
                logger.info(f"Loss {label}: {value.item():.3e}")

            # Print errors
            logger.info(f"Error u: {error_u:.3e}")
            logger.info(f"Error v: {error_v:.3e}")
            logger.info(f"Error p: {error_p:.3e}")

            # Print iteration time
            end_iter_time = time.time()
            logger.info(
                f"This iteration took {end_iter_time - start_iter_time:.2f} seconds"
            )
            logger.info("-" * 50)  # Add a separator for clarity

    logger.info("Physics-informed fine-tuning training completed!")

    # Final inference with the fine-tuned PINN model, in eval mode.
    pi_fine_tuner.model.eval()
    with torch.no_grad():
        x_int_inf, y_int_inf = (
            pi_fine_tuner.coords[:, 0:1],
            pi_fine_tuner.coords[:, 1:2],
        )
        results_int_inf = pi_fine_tuner.graph.forward({"x": x_int_inf, "y": y_int_inf})

    # Save results next to the original fields in the same file.
    polydata = pv.read(path)
    for field in ("u", "v", "p"):
        polydata[f"filtered_{field}"] = results_int_inf[field].detach().cpu().numpy()
    logger.info(f"Saving fine-tuned fields to {path}")
    polydata.save(path)
    logger.info("Inference completed!")


if __name__ == "__main__":
    main()