Skip to content

Commit

Permalink
save cfg of trains
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcellocosti committed Jan 3, 2025
1 parent 0b1df2f commit 4455cf4
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions run3/utils/train_output_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ROOT
import uproot
from sklearn.model_selection import train_test_split
import xml.etree.ElementTree as ET
# pylint: disable=no-member

def split_dataset(df, config):
Expand Down Expand Up @@ -158,6 +159,43 @@ def download_aod(input_list_filename, config):
merge_aod.OutputFile(analysis_output_filename, "RECREATE")
merge_aod.Merge()

def download_full_config(input_list_filename, config):
"""
Downloads full_config.json file associated to train run
Args:
input_list_filename (str): The path to the file containing the list of input file names.
config (dict): A dictionary containing configuration parameters.
Returns:
None
"""

# Create folder for output
folder = "MC" if config["isMC"] else "Data"
output_directory = config["output_directory"] + "/" + folder + f"/Train{config['train_run']}"
if not os.path.exists(output_directory):
os.makedirs(output_directory)

os.system(f"alien.py cp {input_list_filename[0]}/Stage_1.xml file:{output_directory}")

# Load and parse the XML file
xml_file = f"{output_directory}/Stage_1.xml"
tree = ET.parse(xml_file)
root = tree.getroot()

# Find the specific event with name="1"
event = root.find(".//event[@name='1']") # XPath to find <event> with attribute name="1"
if event is not None:
file_element = event.find("file")
if file_element is not None:
lfn = os.path.dirname(file_element.get("lfn"))
os.system(f"alien.py cp {lfn}/full_config.json file:{output_directory}")
with open(f"{output_directory}/full_config.json", "r") as infile:
data = json.load(infile)
with open(f"{output_directory}/full_config.json", "w") as outfile:
json.dump(data, outfile, indent=4)
os.remove(xml_file)
print('full_config.json loaded!')

def convert_aod_to_parquet(config): # pylint: disable=too-many-locals
"""
Converts AOD (Analysis Object Data) file to Parquet format.
Expand Down Expand Up @@ -212,7 +250,7 @@ def convert_aod_to_parquet(config): # pylint: disable=too-many-locals
df_eff.to_parquet(output_filename.replace(".parquet", "_Eff.parquet"))
del df

def download_files_from_grid(config, aod=False, analysis=False, parquet=False):
def download_files_from_grid(config, aod=False, analysis=False, parquet=False, full_config=False):
"""
Downloads files from the grid based on the provided configuration.
Expand All @@ -221,6 +259,7 @@ def download_files_from_grid(config, aod=False, analysis=False, parquet=False):
aod (bool, optional): Flag indicating whether to download AOD files.
analysis (bool, optional): Flag indicating whether to download analysis results.
parquet (bool, optional): Flag indicating whether to convert AOD files to Parquet format.
full_config (bool, optional): Flag indicating whether to download the full_config.json file.
Note:
If no flags are provided, all operations are performed.
"""
Expand All @@ -229,6 +268,8 @@ def download_files_from_grid(config, aod=False, analysis=False, parquet=False):
aod = True
analysis = True
parquet = True
if not full_config:
full_config = True

if aod or analysis:
# Get files from grid
Expand All @@ -242,6 +283,8 @@ def download_files_from_grid(config, aod=False, analysis=False, parquet=False):
download_aod(files_to_download, config)
if analysis:
download_analysis_results(files_to_download, config)
if full_config:
download_full_config(files_to_download, config)
if parquet:
convert_aod_to_parquet(config)

Expand All @@ -256,10 +299,12 @@ def download_files_from_grid(config, aod=False, analysis=False, parquet=False):
help='Run only the analysis results download and merge')
parser.add_argument('--parquet', action='store_true', default=False,
help='Run only the conversion to Parquet')
parser.add_argument('--config', action='store_true', default=False,
help='Download only the full_config.json')
args = parser.parse_args()


with open(args.config_file, encoding="utf8") as cfg_file:
cfg = json.load(cfg_file)

download_files_from_grid(cfg, aod=args.aod, analysis=args.analysis, parquet=args.parquet)
download_files_from_grid(cfg, aod=args.aod, analysis=args.analysis, parquet=args.parquet, full_config=args.config)

0 comments on commit 4455cf4

Please sign in to comment.