Skip to content

Commit

Permalink
Merge pull request #112 from ikrommyd/support-coffea-casa
Browse files: browse the repository at this point in the history
feat: support coffea-casa for execution
  • Loading branch information
ikrommyd authored Oct 23, 2024
2 parents d4bea53 + c420adb commit 1e1c4a1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
40 changes: 38 additions & 2 deletions scripts/run_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def main():
runner_utils.set_binning(runner_utils.load_json(args.binning))
fileset = runner_utils.load_json(args.fileset)
logger.info(f"Loaded fileset from {args.fileset}")

if args.executor == "dask/casa" or args.executor.startswith("tls://"):
# use xcache for coffea-casa
xrootd_pfx = "root://"
xrd_pfx_len = len(xrootd_pfx)
for dataset in fileset.keys():
files = fileset[dataset]["files"]
newfiles = {}
for path, value in files.items():
newpath = path.replace(path[xrd_pfx_len : xrd_pfx_len + path[xrd_pfx_len:].find("/store")], "xcache/")
newfiles[newpath] = value
fileset[dataset]["files"] = newfiles

if args.preprocess:
from coffea.dataset_tools import preprocess

Expand All @@ -63,9 +76,10 @@ def main():
with gzip.open("/tmp/preprocessed_fileset.json.gz", "wt") as f:
logger.info("Saving the preprocessed fileset to /tmp/preprocessed_fileset.json.gz")
json.dump(fileset, f, indent=2)

instance = runner_utils.initialize_class(config, args, fileset)

if args.port is not None:
if args.port is not None and (not args.executor.startswith("tls://") and not args.executor.startswith("tcp://") and not args.executor.startswith("ucx://")):
if not runner_utils.check_port(args.port):
logger.error(f"Port {args.port} is occupied in this node. Try another one.")
raise ValueError(f"Port {args.port} is occupied in this node. Try another one.")
Expand Down Expand Up @@ -130,6 +144,18 @@ def main():
],
)
scheduler = "distributed"
elif args.executor == "dask/casa":
from coffea_casa import CoffeaCasaCluster

logger.info("Running using CoffeaCasaCluster")
cluster = CoffeaCasaCluster(
cores=args.cores,
memory=args.memory,
disk=args.disk,
scheduler_options={"port": args.port, "dashboard_address": args.dashboard_address},
log_directory=args.log_directory,
)
scheduler = "distributed"
elif args.executor == "dask/slurm":
from dask_jobqueue import SLURMCluster

Expand Down Expand Up @@ -165,7 +191,7 @@ def main():
scheduler_options={"dashboard_address": args.dashboard_address},
)
scheduler = "distributed"
elif args.executor is not None and (args.executor.startswith("tls:://") or args.executor.startswith("tcp://") or args.executor.startswith("ucx://")):
elif args.executor is not None and (args.executor.startswith("tls://") or args.executor.startswith("tcp://") or args.executor.startswith("ucx://")):
logger.info(f"Will use dask scheduler at {args.executor}")
elif args.executor is None:
logger.info("Running with default dask scheduler")
Expand All @@ -180,6 +206,11 @@ def main():
cluster.scale(args.scaleout)
logger.info(f"Set up cluster {cluster}")
client = Client(cluster)
if args.executor == "dask/casa" or args.executor.startswith("tls://"):
from dask.distributed import PipInstall

plugin = PipInstall(packages=["egamma-tnp@git+https://${TOKEN}@github.com/ikrommyd/egamma-tnp.git@master"])
client.register_plugin(plugin)
logger.info(f"Set up client {client}")
if args.executor is not None and (args.executor.startswith("tls://") or args.executor.startswith("tcp://") or args.executor.startswith("ucx://")):
client = Client(args.executor)
Expand Down Expand Up @@ -208,6 +239,11 @@ def main():
out = runner_utils.process_out(out, args.output)
logger.info(f"Final output after post-processing:\n{out}")
logger.info("Finished the E/Gamma Tag and Probe workflow")
logger.info("Shutting down the client and cluster if any")
if client:
client.shutdown()
if cluster:
cluster.close()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/egamma_tnp/utils/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def save_array_to_parquet(array, output_dir, dataset, subdir, prefix=None, repar

# Trick to reduce node multiplicity before to_parquet until it's fixed.
# TODO: remove this when the issue is fixed
array = array[array.run > -999.0]
array = array[0:]

# Repartition the array if needed
if repartition_n:
Expand Down

0 comments on commit 1e1c4a1

Please sign in to comment.