Skip to content

Commit

Permalink
added rich HTML output for MC_GEOPHIRES. Also fixed bug for when run …
Browse files Browse the repository at this point in the history
…from command line
  • Loading branch information
malcolm-dsider committed Mar 12, 2024
1 parent 07e7eaa commit 9c9919f
Showing 1 changed file with 197 additions and 2 deletions.
199 changes: 197 additions & 2 deletions src/geophires_monte_carlo/MC_GeoPHIRES3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rich.console import Console
from rich.table import Table

from geophires_monte_carlo.common import _get_logger
from geophires_x.Parameter import OutputParameter
from geophires_x.Parameter import floatParameter
from geophires_x_client import GeophiresInputParameters
from geophires_x_client import GeophiresXClient
from geophires_x_client import GeophiresXResult
Expand All @@ -33,6 +37,184 @@
from hip_ra_x import HipRaXClient


def Write_HTML_Output(
html_path: str,
df: pd.DataFrame,
outputs: list,
mins: list,
maxs: list,
medians: list,
averages: list,
means: list,
std: list,
full_names: set,
short_names: set,
) -> None:
"""
Write_HTML_Output - write the results of the Monte Carlo simulation to an HTML file
:param html_path: the path to the HTML file to write
:type html_path: str
:param df: the DataFrame with the results
:type df: pd.DataFrame
:param outputs: the list of output variable names
:type outputs: list
:param mins: the list of minimum values for each output variable
:type mins: list
:param maxs: the list of maximum values for each output variable
:type maxs: list
:param medians: the list of median values for each output variable
:type medians: list
:param averages: the list of average values for each output variable
:type averages: list
:param means: the list of mean values for each output variable
:type means: list
:param std: the list of standard deviation values for each output variable
:type std: list
:param full_names: the list of full names for each output variable
:type full_names: set
:param short_names: the list of short names for each output variable
:type short_names: set
"""

# Build the tables that will hold those results, along with the columns for the input variables
results_table = Table(title='GEOPHIRES/HIR-RA Monte Carlo Results')
results_table.add_column('Iteration #', no_wrap=True, justify='center')
for output in df.axes[1]:
results_table.add_column(output.replace(',', ''), no_wrap=True, justify='center')

statistics_table = Table(title='GEOPHIRES/HIR-RA Monte Carlo Statistics')
statistics_table.add_column('Output Parameter Name', no_wrap=True, justify='center')
statistics_table.add_column('minimum', no_wrap=True, justify='center')
statistics_table.add_column('maximum', no_wrap=True, justify='center')
statistics_table.add_column('median', no_wrap=True, justify='center')
statistics_table.add_column('average', no_wrap=True, justify='center')
statistics_table.add_column('mean', no_wrap=True, justify='center')
statistics_table.add_column('standard deviation', no_wrap=True, justify='center')

# Iterate over the rows of the DataFrame and add them to the results table
for index, row in df.iterrows():
data = row.values[0 : len(outputs)]

# have to deal with the special case where thr last column is actually
# a compound string with multiple columns in it that looks like this:
# ' (Gradient 1:47.219846973456924;Reservoir Temperature:264.7789623351493;...)'
str_to_parse = str(row.values[len(outputs)]).strip().replace('(', '').replace(')', '')
fields = str_to_parse.split(';')
for field in fields:
if len(field) > 0:
key, value = field.split(':')
data = np.append(data, float(value))

results_table.add_row(str(int(index)), *[render_default(d) for d in data])

for i in range(len(outputs)):
statistics_table.add_row(
outputs[i],
render_default(mins[i]),
render_default(maxs[i]),
render_default(medians[i]),
render_default(averages[i]),
render_default(means[i]),
render_default(std[i]),
)

console = Console(style='bold white on black', force_terminal=True, record=True, width=500)
console.print(results_table)
console.print(' ')
console.print(statistics_table)
console.save_html(html_path)

# Write a reference to the image(s) into the HTML file by inserting before the "</body>" tag
# build the string to be inserted first
insert_string = ''
for _ in range(len(full_names)):
name_to_use = short_names.pop()
insert_string = insert_string + f'<img src="{name_to_use}.png" alt="{name_to_use}">\n'

match_string = '</body>'
with open(html_path, 'r+', encoding='UTF-8') as html_file:
contents = html_file.readlines()
if match_string in contents[-1]: # Handle last line to prevent IndexError
pass
else:
for index, line in enumerate(contents):
if match_string in line and insert_string not in contents[index + 1]:
contents.insert(index, insert_string)
break
html_file.seek(0)
html_file.writelines(contents)


def UpgradeSymbologyOfUnits(unit: str) -> str:
"""
UpgradeSymbologyOfUnits is a function that takes a string that represents a unit and replaces the **2 and **3
with the appropriate unicode characters for superscript 2 and 3, and replaces "deg" with the unicode character
for degrees.
:param unit: a string that represents a unit
:return: a string that represents a unit with the appropriate unicode characters for superscript 2 and 3, and
replaces "deg" with the unicode character for degrees.
"""
return unit.replace('**2', '\u00b2').replace('**3', '\u00b3').replace('deg', '\u00b0')


def render_default(p: float, unit: str = '') -> str:
"""
RenderDefault - render a float as a string with 2 decimal places, or in scientific notation if it is greater than
10,000 with the unit appended to it if it is not an empty string (the default)
:param p: the float to render
:type p: float
:param unit: the unit to append to the string
:type unit: str
:return: the string representation of the float
:rtype: str
"""
unit = UpgradeSymbologyOfUnits(unit)
# if the number is greater than 10,000, render it in scientific notation
if p > 10_000:
return f'{p:10.2e} {unit}'.strip()
# otherwise, render it with 2 decimal places
else:
return f'{p:10.2f} {unit}'.strip()


def render_scientific(p: float, unit: str = '') -> str:
"""
RenderScientific - render a float as a string in scientific notation with 2 decimal places
and the unit appended to it if it is not an empty string (the default)
:param p: the float to render
:type p: float
:param unit: the unit to append to the string
:type unit: str
:return: the string representation of the float
:rtype: str
"""
unit = UpgradeSymbologyOfUnits(unit)
return f'{p:10.2e} {unit}'.strip()


def render_Parameter_default(p: floatParameter | OutputParameter) -> str:
"""
RenderDefault - render a float as a string with 2 decimal places, or in scientific notation if it is greater than
10,000 with the unit appended to it if it is not an empty string (the default) by calling the render_default base
function
:param p: the parameter to render
:type p: float
:return: the string representation of the float
"""
return render_default(p.value, p.CurrentUnits.value)


def render_parameter_scientific(p: floatParameter | OutputParameter) -> str:
"""
RenderScientific - render a float as a string in scientific notation with 2 decimal places
and the unit appended to it if it is not an empty string (the default) by calling the render_scientific base function
:param p: the parameter to render
:type p: float
:return: the string representation of the float
"""
return render_scientific(p.value, p.CurrentUnits.value)


def check_and_replace_mean(input_value, args) -> list:
"""
CheckAndReplaceMean - check to see if the user has requested that a value be replaced by a mean value by specifying
Expand Down Expand Up @@ -268,7 +450,9 @@ def main(command_line_args=None):
if 'MC_OUTPUT_FILE' in args and args.MC_OUTPUT_FILE is not None
else str(Path(Path(args.Input_file).parent, 'MC_Result.txt').absolute())
)
args.MC_OUTPUT_FILE = output_file
python_path = 'python'
html_path = ''

for line in flist:
clean = line.strip()
Expand All @@ -284,6 +468,8 @@ def main(command_line_args=None):
output_file = pair[1]
elif pair[0].startswith('PYTHON_PATH'):
python_path = pair[1]
elif pair[0].startswith('HTML_PATH'):
html_path = pair[1]

# check to see if there is a "#" in an input, if so, use the results file to replace it with the value
for input_value in inputs:
Expand Down Expand Up @@ -375,6 +561,8 @@ def main(command_line_args=None):
# write them out
annotations = ''
outputs_result: dict[str, dict] = {}
full_names: set = set()
short_names: set = set()
with open(output_file, 'a') as f:
if iterations != actual_records_count:
f.write(
Expand Down Expand Up @@ -408,10 +596,17 @@ def main(command_line_args=None):
f.write(f'bin values (as percentage): {ret[0]!s}\n')
f.write(f'bin edges: {ret[1]!s}\n')
fname = df.columns[i].strip().replace('/', '-')
plt.savefig(Path(Path(output_file).parent, f'{fname}.png'))

save_path = Path(Path(output_file).parent, f'{fname}.png')
if html_path:
save_path = Path(Path(html_path).parent, f'{fname}.png')
plt.savefig(save_path)
full_names.add(save_path)
short_names.add(fname)
annotations = ''

if html_path:
Write_HTML_Output(html_path, df, outputs, mins, maxs, medians, averages, means, std, full_names, short_names)

with open(Path(output_file).with_suffix('.json'), 'w') as json_output_file:
json_output_file.write(json.dumps(outputs_result))
logger.info(f'Wrote JSON results to {json_output_file.name}')
Expand Down

0 comments on commit 9c9919f

Please sign in to comment.