Skip to content

Commit

Permalink
add a ring demo perf
Browse files Browse the repository at this point in the history
  • Loading branch information
erizmr committed Oct 12, 2024
1 parent 7ad13ee commit f55e32f
Showing 1 changed file with 131 additions and 0 deletions.
131 changes: 131 additions & 0 deletions test_ring_demo_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import time
from types import SimpleNamespace
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
import pickle
import warpmesh as wm

print("Setting up solver.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("!!!!!device!!!!!! ", device)
#################### Load trained model ####################

with open("./pretrain_model/config.yaml", "r") as file:
config_data = yaml.safe_load(file)

config = SimpleNamespace(**config_data)
config.mesh_feat = ["coord", "monitor_val"]

print(config)

model = wm.M2N_T(
deform_in_c=config.num_deform_in,
gfe_in_c=config.num_gfe_in,
lfe_in_c=config.num_lfe_in,
)
model_file_path = "./pretrain_model/model_999.pth"
model = wm.load_model(model, model_file_path)
model.eval()
model = model.to(device)
###########################################################

model_results = "./ring_demo_data/ring_ref_results.pkl"
input_sample_path = "./ring_demo_data/input_sample_data.pkl"
mesh_path = "./ring_demo_data/ring_demo_mesh.msh"

with open(model_results, "rb") as f:
plot_data_dict_model = pickle.load(f)
print(plot_data_dict_model)

with open(input_sample_path, "rb") as f:
input_sample_data = pickle.load(f)
print(input_sample_data)

sample = input_sample_data.to(device)
total_infer_time = 0.0
all_infer_time = []
num_run = 20
with torch.no_grad():
for _ in range(num_run):
start_time = time.perf_counter()
adapted_coord = model(sample)
end_time = time.perf_counter()
curr_infer_time = (end_time - start_time)*1e3
all_infer_time.append(curr_infer_time)
total_infer_time += curr_infer_time
averaged_time = total_infer_time/num_run
print(f"Total model inference time: {total_infer_time} ms, averaged time: {averaged_time}")

# Check result
reference_adapted_mesh = plot_data_dict_model["mesh_model"]
assert np.allclose(adapted_coord.cpu().detach().numpy(), reference_adapted_mesh, rtol=1e-05, atol=1e-08), "Model output mesh is not consistent to the reference"
print("Output is consistent to the reference.")

output_file = "./test_ring_demo_perf_out.txt"
print(all_infer_time)
with open(output_file, "w") as f:
f.write(', '.join([str(v) for v in all_infer_time]))
f.write('\n')
f.write('average time: ' + str(averaged_time) + '\n')
f.write('total time: ' + str(total_infer_time) + '\n')
print(f"write results to {output_file}.")


rows = 3
cols = 2
cmap = "seismic"

fig, ax = plt.subplots(
rows, cols, figsize=(cols * 10, rows * 10), layout="compressed"
)


## Firedrake visualization part
import firedrake as fd

mesh_og = fd.Mesh(mesh_path)
mesh_refer = fd.Mesh(mesh_path)
mesh_model = fd.Mesh(mesh_path)


og_function_space = fd.FunctionSpace(mesh_og, "CG", 1)
model_function_space = fd.FunctionSpace(mesh_model, "CG", 1)
mesh_refer_function_space = fd.FunctionSpace(mesh_refer, "CG", 1)

u_og = fd.Function(fd.FunctionSpace(mesh_og, "CG", 1))
u_ma = fd.Function(fd.FunctionSpace(mesh_refer, "CG", 1))
u_model = fd.Function(fd.FunctionSpace(mesh_model, "CG", 1))
monitor_values = fd.Function(og_function_space)

u_og_data = plot_data_dict_model["u_original"]
u_og.dat.data[:] = u_og_data

rows = 1
cols = 4
cmap = "seismic"
FONT_SIZE = 24

fig, ax = plt.subplots(
rows, cols, figsize=(cols * 10, rows * 10), layout="compressed"
)

fd.triplot(mesh_og, axes=ax[0])
ax[0].set_title("Original Mesh", fontsize=FONT_SIZE)
fd.tripcolor(u_og, axes=ax[1], cmap=cmap)
ax[1].set_title("Solution", fontsize=FONT_SIZE)

# Adapted mesh
mesh_model.coordinates.dat.data[:] = adapted_coord.cpu().detach().numpy()
fd.triplot(mesh_model, axes=ax[2])
ax[2].set_title("Adapated Mesh (UM2N)", fontsize=FONT_SIZE)

mesh_refer.coordinates.dat.data[:] = plot_data_dict_model["mesh_model"]
fd.triplot(mesh_model, axes=ax[3])
ax[3].set_title("Adapated Mesh (UM2N) Reference", fontsize=FONT_SIZE)


plt.savefig("test_ring_demo_perf.png")
plt.show()

0 comments on commit f55e32f

Please sign in to comment.