diff --git a/azure/batch/scripts/train_model_finetune_on_catalog.py b/azure/batch/scripts/train_model_finetune_on_catalog.py index 90c70bc..4b6cc32 100644 --- a/azure/batch/scripts/train_model_finetune_on_catalog.py +++ b/azure/batch/scripts/train_model_finetune_on_catalog.py @@ -47,7 +47,7 @@ 'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']} } } - schema = args.schema + schema = schema_dict.get(args.schema, cosmic_dawn_ortho_schema) # setup the error reporting tool - https://app.honeybadger.io/projects/ honeybadger_api_key = os.getenv('HONEYBADGER_API_KEY') if honeybadger_api_key: @@ -71,7 +71,7 @@ kade_catalog['file_loc'].iloc[len(kade_catalog.index) - 1])) datamodule = GalaxyDataModule( - label_cols=schema_dict[args.schema].label_cols, + label_cols=schema.label_cols, catalog=kade_catalog, batch_size=args.batch_size, num_workers=args.num_workers, @@ -109,7 +109,7 @@ model = finetune.FinetuneableZoobotTree( checkpoint_loc=args.checkpoint, # params specific to tree finetuning - schema=schema_dict[args.schema], + schema=schema, # params for superclass i.e. any finetuning encoder_dim=args.encoder_dim, n_layers=args.n_layers,