Skip to content

Commit

Permalink
Inital implementation of callback function with some quick fixes of e…
Browse files Browse the repository at this point in the history
…xisting functions for compatability
  • Loading branch information
Morgan Thomas committed Oct 6, 2024
1 parent 054f893 commit 97f8146
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class MonteCarlo:
"""

def __init__(
self, filename, environment, rocket, flight, export_list=None
self, filename, environment, rocket, flight, export_list=None,
export_function=None
): # pylint: disable=too-many-statements
"""
Initialize a MonteCarlo object.
Expand All @@ -104,6 +105,11 @@ def __init__(
`out_of_rail_stability_margin`, `out_of_rail_time`,
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
`lateral_surface_wind`. Default is None.
export_function : callable, optional
A function which gets called at the end of a simulation to collect
additional data to be exported that isn't pre-defined. Takes the
Flight object as an argument and returns a dictionary. Default is None.
Returns
-------
Expand Down Expand Up @@ -132,6 +138,7 @@ def __init__(
self._last_print_len = 0 # used to print on the same line

self.export_list = self.__check_export_list(export_list)
self.export_function = export_function

try:
self.import_inputs()
Expand Down Expand Up @@ -359,6 +366,13 @@ def __export_flight_data(
for export_item in self.export_list
}

if self.export_function is not None:
additional_exports = self.export_function(flight)
for key in additional_exports.keys():
if key in self.export_list:
raise ValueError(f"Invalid export function, returns dict which overwrites key, '{key}'")
results = results | additional_exports

input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")

Expand Down Expand Up @@ -654,9 +668,12 @@ def set_processed_results(self):
"""
self.processed_results = {}
for result, values in self.results.items():
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
try:
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
except TypeError:
self.processed_results[result] = (None, None)

# Import methods

Expand Down

0 comments on commit 97f8146

Please sign in to comment.