Skip to content

Commit

Permalink
Generate plots
Browse files Browse the repository at this point in the history
  • Loading branch information
gmarupilla committed Mar 26, 2021
1 parent 08ef10f commit 068107b
Showing 1 changed file with 96 additions and 1 deletion.
97 changes: 96 additions & 1 deletion cli_util/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import pandas as pd
import shutil
import tempfile

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import libsedml as lsed
from libsedml import SedReport, SedPlot2D

def exec_sed_doc(omex_file_path, base_out_path):
archive_tmp_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -73,8 +77,99 @@ def transpose_vcml_csv(csv_file_path: str):
df[final_cols].transpose().to_csv(csv_file_path, header=False, index=False)


def get_all_dataref_and_curves(sedml_path):
all_plot_curves = {}
all_report_dataref = {}

sedml = lsed.readSedML(sedml_path)

for output in sedml.getListOfOutputs():
if type(output) == SedPlot2D:
all_curves = {}
for curve in output.getListOfCurves():
all_curves[curve.getId()] = {
'x': curve.getXDataReference(),
'y': curve.getYDataReference()
}
all_plot_curves[output.getId()] = all_curves
if type(output) == SedReport:
for dataset in output.getListOfDataSets():

######
if output.getId() in all_report_dataref:

all_report_dataref[output.getId()].append({
'data_reference': dataset.getDataReference(),
'data_label': dataset.getLabel()
})
else:
all_report_dataref[output.getId()] = []
all_report_dataref[output.getId()].append({
'data_reference': dataset.getDataReference(),
'data_label': dataset.getLabel()
})


return all_report_dataref, all_plot_curves



def get_report_label_from_data_ref(dataref: str, all_report_dataref):
for report in all_report_dataref.keys():
for data_ref in all_report_dataref[report]:
if dataref == data_ref['data_reference']:
return report, data_ref['data_label']


### Update plots dict

def update_dataref_with_report_label(all_report_dataref, all_plot_curves):

for plot,curves in all_plot_curves.items():
for curve_name,datarefs in curves.items():
new_ref = dict(datarefs)
new_ref['x'] = get_report_label_from_data_ref(datarefs['x'], all_report_dataref)[1]
new_ref['y'] = get_report_label_from_data_ref(datarefs['y'], all_report_dataref)[1]
new_ref['report'] = get_report_label_from_data_ref(datarefs['y'], all_report_dataref)[0]
curves[curve_name] = new_ref

return all_report_dataref, all_plot_curves



def get_report_dataframes(all_report_dataref, result_out_dir):
report_frames = {}
reports_list = list(set(all_report_dataref.keys()))
for report in reports_list:
report_frames[report] = pd.read_csv(os.path.join(result_out_dir, report + ".csv")).T.reset_index()
report_frames[report].columns = report_frames[report].iloc[0].values
report_frames[report].drop(index = 0, inplace=True)
return report_frames

## PLOTTING

def plot_and_save_curves(all_plot_curves, report_frames, result_out_dir):
all_plots = dict(all_plot_curves)
for plot, curve_dat in all_plots.items():
dims = (12, 8)
fig, ax = plt.subplots(figsize=dims)
for curve, data in curve_dat.items():
df = report_frames[data['report']]
sns.lineplot(x=df[data['x']].astype(np.float), y=df[data['y']].astype(np.float), ax=ax, label=curve)
ax.set_ylabel('')
# plt.show()
plt.savefig(os.path.join(result_out_dir, plot + '.pdf'), dpi=300)

def gen_plot_pdfs(sedml_path, result_out_dir):
all_report_dataref, all_plot_curves = get_all_dataref_and_curves(sedml_path)
all_report_dataref, all_plot_curves = update_dataref_with_report_label(all_report_dataref, all_plot_curves)
report_frames = get_report_dataframes(all_report_dataref, result_out_dir)
plot_and_save_curves(all_plot_curves, report_frames, result_out_dir)


if __name__ == "__main__":
fire.Fire({
'execSedDoc': exec_sed_doc,
'transposeVcmlCsv': transpose_vcml_csv,
'genPlotPdfs': gen_plot_pdfs,
})

0 comments on commit 068107b

Please sign in to comment.