Skip to content

Commit

Permalink
Use median to normalise. Save eval to json.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chia-Man Hung committed Aug 14, 2023
1 parent 1cedaed commit 338d747
Showing 1 changed file with 41 additions and 13 deletions.
54 changes: 41 additions & 13 deletions sunerf/evaluation/density_cube_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,34 @@
from sunpy.map import Map
from datetime import datetime
import pickle
import json

from sunerf.evaluation.loader import SuNeRFLoader
from sunerf.utilities.data_loader import normalize_datetime

START_STEPNUM = 37 # 5
END_STEPNUM = 37 # 74
'''
python -m sunerf.evaluation.density_cube_eval
'''

START_STEPNUM = 5
END_STEPNUM = 74
CHUNKS = 4

# R_SUN_CM = 6.957e+10
# GRID_SIZE = 500 / 16 # solar radii

ignore_half_of_r = True

# og-eda and og-eda-3
ckpt_dirname = "HAO_pinn_cr_2view_a26978f_heliographic_reformat"
# og-eda
# ckpt_dirname = "HAO_pinn_2view_no_physics"
# ckpt_dirname = "HAO_pinn_2view_cr"
# og-eda-3
# ckpt_dirname = "HAO_pinn_2view_cr3"
# ckpt_dirname = "HAO_pinn_1view_cr3"


def save_stepnum_to_datetime():
stepnum_to_datetime = dict()

Expand Down Expand Up @@ -45,23 +62,25 @@ def dtstr_to_datetime(dtstr):

mae_all_stepnums = []

for stepnum in range(START_STEPNUM, END_STEPNUM + 1, 1):
for stepnum in tqdm(range(START_STEPNUM, END_STEPNUM + 1, 1)):

# load ground truth
gt_fname = "/mnt/ground-data/density_cube/dens_stepnum_%03d.sav" % stepnum
o = scipy.io.readsav(gt_fname)
ph = o['ph1d'] # (258,)
th = o['th1d'] # (128,)
r = o['r1d'] # (256,)
density_gt = o['dens'] # (258, 128, 256) (phi, theta, r)


# ignore half of r
# r_size = len(o['r1d'])
# r = o['r1d'][:int(r_size / 2)] # (256,) -> (128, 0)
# density_gt = o['dens'][:,:,:int(r_size / 2)] # (258, 128, 256) (phi, theta, r)
if ignore_half_of_r:
r_size = len(o['r1d'])
r = o['r1d'][:int(r_size / 2)] # (256,) -> (128, 0)
density_gt = o['dens'][:,:,:int(r_size / 2)] # (258, 128, 256) (phi, theta, r)
else:
r = o['r1d'] # (256,)
density_gt = o['dens'] # (258, 128, 256) (phi, theta, r)

# load model checkpoint
base_path = '/mnt/training/HAO_pinn_cr_2view_a26978f_heliographic_reformat'
base_path = '/mnt/training/' + ckpt_dirname
chk_path = os.path.join(base_path, 'save_state.snf')
loader = SuNeRFLoader(chk_path, resolution=512)

Expand Down Expand Up @@ -100,8 +119,8 @@ def dtstr_to_datetime(dtstr):
# density *= GRID_SIZE ** (-2) * R_SUN_CM ** (-3)

# compare density to ground truth
rel_density = density / np.mean(density)
rel_density_gt = density_gt / np.mean(density_gt)
rel_density = density / np.median(density)
rel_density_gt = density_gt / np.median(density_gt)

print(rel_density[0])
print(rel_density_gt[0])
Expand All @@ -110,4 +129,13 @@ def dtstr_to_datetime(dtstr):

mae_all_stepnums.append(mae)

print(sum(mae_all_stepnums) / len(mae_all_stepnums))
print(mae_all_stepnums)
mae_avg = sum(mae_all_stepnums) / len(mae_all_stepnums)
print(mae_avg)


# save eval to json
output_fname = "eval_half.json" if ignore_half_of_r else "eval.json"
eval_dict = {"mae_all_stepnums": mae_all_stepnums, "mae_avg": mae_avg}
with open(os.path.join(base_path, output_fname), 'w') as fp:
json.dump(eval_dict, fp)

0 comments on commit 338d747

Please sign in to comment.