Skip to content

Commit

Permalink
ENH: Add custom JSON encoder for RocketPy objects
Browse files Browse the repository at this point in the history
- This improves the file saving/reading methods
- Sorry for implementing a new feature in this PR, it it feels like the right thing to do right now. I'm running late sorry.
  • Loading branch information
Gui-FernandesBR committed Mar 15, 2024
1 parent 5fefdf9 commit ccf32a4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 82 deletions.
43 changes: 43 additions & 0 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import json
import types

import numpy as np

from rocketpy.mathutils.function import Function


class RocketPyEncoder(json.JSONEncoder):
"""NOTE: This is still under construction, please don't use it yet."""

def default(self, o):
if isinstance(
o,
(
np.int_,
np.intc,
np.intp,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
),
):
return int(o)
elif isinstance(o, (np.float_, np.float16, np.float32, np.float64)):
return float(o)
elif isinstance(o, np.ndarray):
return o.tolist()
elif hasattr(o, "to_dict"):
return o.to_dict()
# elif isinstance(o, Function):
# return o.__dict__()
elif isinstance(o, (Function, types.FunctionType)):
return repr(o)
else:
return json.JSONEncoder.default(self, o)
102 changes: 20 additions & 82 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import simplekml

from rocketpy._encoders import RocketPyEncoder
from rocketpy.plots.monte_carlo_plots import _MonteCarloPlots
from rocketpy.prints.monte_carlo_prints import _MonteCarloPrints
from rocketpy.simulation.flight import Flight
Expand Down Expand Up @@ -260,8 +261,9 @@ def __export_flight_data(
}

# Write flight setting and results to file
input_file.write(f"{inputs_dict}\n")
output_file.write(f"{results}\n")
# TODO: use json.dumps (requires custom JSONEncoder child class)
input_file.write(f"{json.dumps(inputs_dict, cls=RocketPyEncoder)}\n")
output_file.write(f"{json.dumps(results, cls=RocketPyEncoder)}\n")

def __check_export_list(self, export_list):
"""Checks if the export_list is valid and returns a valid list. If no
Expand Down Expand Up @@ -421,80 +423,29 @@ def error_file(self, value):
# setters for post simulation attributes
def set_inputs_log(self):
"""Sets inputs_log from a file into an attribute for easy access"""
# TODO: add pickle package to deal with parachute triggers and Function objects
self.inputs_log = []
with open(self.input_file, mode="r", encoding="utf-8") as inputs:
# Loop through each line in the file
for line in inputs:
# swap "<" and ">" to "'"
# this is done to interpret the trigger functions
# Find the index of the first and last occurrences of '<' and '>'
first_lt_index = line.find("<")
last_gt_index = line.rfind(">")
if first_lt_index != -1 and last_gt_index != -1:
# Replace the first '<' and last '>' with double quotes
line = (
line[:first_lt_index]
+ '"'
+ line[first_lt_index + 1 : last_gt_index]
+ '"'
+ line[last_gt_index + 1 :]
)

# # Skip comments lines
if line[0] != "{":
continue
# Try to convert the line to a dictionary
try:
d = json.loads(line)
# If successful, append the dictionary to the list
self.inputs_log.append(d)
except json.JSONDecodeError:
continue
with open(self.input_file, mode="r", encoding="utf-8") as rows:
for line in rows:
self.inputs_log.append(json.loads(line))

def set_outputs_log(self):
"""Sets outputs_log from a file into an attribute for easy access"""
self.outputs_log = []
# Loop through each line in the file
with open(self.output_file, mode="r", encoding="utf-8") as outputs:
for line in outputs:
# Skip comments lines
if line[0] != "{":
continue
# Try to convert the line to a dictionary
try:
# If successful, append the dictionary to the list
self.outputs_log.append(json.loads(line))
except json.JSONDecodeError:
continue
with open(self.output_file, mode="r", encoding="utf-8") as rows:
for line in rows:
self.outputs_log.append(json.loads(line))

def set_errors_log(self):
"""Sets errors_log log from a file into an attribute for easy access"""
self.errors_log = []
# Loop through each line in the file
with open(self.error_file, mode="r", encoding="utf-8") as errors:
for line in errors:
# Skip comments lines
if line[0] != "{":
continue
# Try to convert the line to a dictionary
try:
# If successful, append the dictionary to the list
self.errors_log.append(json.loads(line))
except json.JSONDecodeError:
continue
self.errors_log.append(json.loads(line))

def set_num_of_loaded_sims(self):
"""Number of simulations loaded from output_file being currently used."""
# Calculate the number of flights simulated
self.num_of_loaded_sims = 0
# Loop through each line in the file
with open(self.output_file, mode="r", encoding="utf-8") as outputs:
for line in outputs:
# Skip comments lines
if line[0] != "{":
continue
self.num_of_loaded_sims += 1
self.num_of_loaded_sims = sum(1 for _ in outputs)

def set_results(self):
"""Monte carlo results organized in a dictionary where the keys are the
Expand Down Expand Up @@ -531,25 +482,19 @@ def import_outputs(self, filename=None):
-------
None
"""
# select file to use
filepath = filename if filename else self.filename

try:
with open(f"{filepath}.outputs.txt", "r+", encoding="utf-8"):
self.output_file = f"{filepath}.outputs.txt"
# Print the number of flights simulated
print(
f"A total of {self.num_of_loaded_sims} simulations results were loaded from"
f" the following output file: {filepath}.outputs.txt\n"
)
except FileNotFoundError:
with open(filepath, "r+", encoding="utf-8"):
self.output_file = filepath
# Print the number of flights simulated
print(
f"A total of {self.num_of_loaded_sims} simulations results were loaded from"
f" the following output file: {filepath}\n"
)

print(
f"A total of {self.num_of_loaded_sims} simulations results were "
f"loaded from the following output file: {self.output_file}\n"
)

def import_inputs(self, filename=None):
"""Import monte carlo results from .txt file and save it into a
Expand All @@ -565,19 +510,16 @@ def import_inputs(self, filename=None):
-------
None
"""
# select file to use
filepath = filename if filename else self.filename

try:
with open(f"{filepath}.inputs.txt", "r+", encoding="utf-8"):
self.input_file = f"{filepath}.inputs.txt"
# Print the number of flights simulated
print(f"The following input file was imported: {filepath}.inputs.txt\n")
except FileNotFoundError:
with open(filepath, "r+", encoding="utf-8"):
self.input_file = filepath
# Print the number of flights simulated
print(f"The following input file was imported: {filepath}\n")

print(f"The following input file was imported: {self.input_file}\n")

def import_errors(self, filename=None):
"""Import monte carlo results from .txt file and save it into a
Expand All @@ -593,19 +535,15 @@ def import_errors(self, filename=None):
-------
None
"""
# select file to use
filepath = filename if filename else self.filename

try:
with open(f"{filepath}.errors.txt", "r+", encoding="utf-8"):
self.error_file = f"{filepath}.errors.txt"
# Print the number of flights simulated
print(f"The following error file was imported: {filepath}.errors.txt\n")
except FileNotFoundError:
with open(filepath, "r+", encoding="utf-8"):
self.error_file = filepath
# Print the number of flights simulated
print(f"The following error file was imported: {filepath}\n")
print(f"The following error file was imported: {self.error_file}\n")

def import_results(self, filename=None):
"""Import monte carlo results from .txt file and save it into a
Expand Down

0 comments on commit ccf32a4

Please sign in to comment.