diff --git a/azure/batch/Dockerfile b/azure/batch/Dockerfile index 91618f4..02c5041 100644 --- a/azure/batch/Dockerfile +++ b/azure/batch/Dockerfile @@ -1,18 +1,26 @@ -FROM nvidia/cuda:11.3.1-base-ubuntu20.04 +FROM nvidia/cuda:12.1.1-base-ubuntu22.04 -ENV DEBIAN_FRONTEND noninteractive +ENV DEBIAN_FRONTEND=noninteractive WORKDIR /usr/src/zoobot +# Install prerequisites and add deadsnakes PPA for Python 3.10 RUN apt-get update && apt-get -y upgrade && \ apt-get install --no-install-recommends -y \ + software-properties-common && \ + add-apt-repository ppa:deadsnakes/ppa && \ + apt-get update && apt-get install --no-install-recommends -y \ build-essential \ - python3 \ + python3.10 \ + python3.10-distutils \ + python3.10-dev \ python3-pip \ git && \ apt-get clean && rm -rf /var/lib/apt/lists/* -RUN ln -s /usr/bin/python3 /usr/bin/python +# Link Python 3.10 as default +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 # install a newer version of pip # as we can't use the use the ubuntu package pip version (20.0.2) @@ -23,6 +31,7 @@ RUN apt-get remove -y python3-pip RUN ln -s /usr/local/bin/pip3 /usr/bin/pip RUN ln -s /usr/local/bin/pip3 /usr/bin/pip3 -# install our dependencies (see setup.py) + +# Install project dependencies (see setup.py) COPY setup.py . -RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu113 +RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu121 \ No newline at end of file diff --git a/azure/batch/scripts/predict_on_catalog.py b/azure/batch/scripts/predict_on_catalog.py index 9ba4cb8..33aa36c 100644 --- a/azure/batch/scripts/predict_on_catalog.py +++ b/azure/batch/scripts/predict_on_catalog.py @@ -64,15 +64,7 @@ def __getitem__(self, idx): # ensure we raise other response errors like 404 and 500 etc # Note: we don't retry on errors that aren't in the `status_forcelist`, instead we fast fail! response.raise_for_status() - url_mime_type = response.headers['content-type'] - # handle PNG images - if url_mime_type == 'image/png': - # use PIL image to read the png file buffer - image = Image.open(response.raw) - else: # but assume all other images are JPEG - # HWC PIL image - image = Image.fromarray( - galaxy_dataset.decode_jpeg(response.raw.read())) + image = Image.open(response.raw) except Exception as e: # add some logging on the failed url logging.critical('Cannot load {}'.format(url)) @@ -98,7 +90,7 @@ def __getitem__(self, idx): class PredictionGalaxyDataModule(galaxy_datamodule.GalaxyDataModule): # override the setup method to setup our prediction dataset on the prediction catalog def setup(self, stage: Optional[str] = None): - self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.transform) + self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.test_transform) def save_predictions_to_json(predictions: np.ndarray, image_ids: List[str], label_cols: List[str], save_loc: str): diff --git a/azure/batch/scripts/train_model_finetune_on_catalog.py b/azure/batch/scripts/train_model_finetune_on_catalog.py index 4b6cc32..3c80a0f 100644 --- a/azure/batch/scripts/train_model_finetune_on_catalog.py +++ b/azure/batch/scripts/train_model_finetune_on_catalog.py @@ -6,7 +6,7 @@ from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule from zoobot.pytorch.training import finetune -from zoobot.shared.schemas import cosmic_dawn_ortho_schema +from zoobot.shared.schemas import cosmic_dawn_ortho_schema, euclid_ortho_schema if __name__ == '__main__': @@ -41,11 +41,7 @@ schema_dict = { 'cosmic_dawn': cosmic_dawn_ortho_schema, - 'euclid': { - 'label_cols': ['smooth-or-featured-euclid_smooth', 'smooth-or-featured-euclid_featured-or-disk', 'smooth-or-featured-euclid_problem', 'disk-edge-on-euclid_yes', 'disk-edge-on-euclid_no', 'has-spiral-arms-euclid_yes', 'has-spiral-arms-euclid_no', 'bar-euclid_strong', 'bar-euclid_weak', 'bar-euclid_no', 'bulge-size-euclid_dominant', 'bulge-size-euclid_large', 'bulge-size-euclid_moderate', 'bulge-size-euclid_small', 'bulge-size-euclid_none', 'how-rounded-euclid_round', 'how-rounded-euclid_in-between', 'how-rounded-euclid_cigar-shaped', 'edge-on-bulge-euclid_boxy', 'edge-on-bulge-euclid_none', 'edge-on-bulge-euclid_rounded', 'spiral-winding-euclid_tight', 'spiral-winding-euclid_medium', 'spiral-winding-euclid_loose', 'spiral-arm-count-euclid_1', 'spiral-arm-count-euclid_2', 'spiral-arm-count-euclid_3', 'spiral-arm-count-euclid_4', 'spiral-arm-count-euclid_more-than-4', 'spiral-arm-count-euclid_cant-tell', 'merging-euclid_none', 'merging-euclid_minor-disturbance', 'merging-euclid_major-disturbance', 'merging-euclid_merger', 'clumps-euclid_yes', 'clumps-euclid_no', 'problem-euclid_star', 'problem-euclid_artifact', 'problem-euclid_zoom', 'artifact-euclid_satellite', 'artifact-euclid_scattered', 'artifact-euclid_diffraction', 'artifact-euclid_ray', 'artifact-euclid_saturation', 'artifact-euclid_other', 'artifact-euclid_ghost'], - 'questions': ['smooth-or-featured-euclid', 'indices 0 to 2', 'asked after None', 'disk-edge-on-euclid', 'indices 3 to 4', 'asked after smooth-or-featured-euclid_featured-or-disk', 'index 1', 'has-spiral-arms-euclid', 'indices 5 to 6', 'asked after disk-edge-on-euclid_no', 'index 4', 'bar-euclid', 'indices 7 to 9', 'asked after disk-edge-on-euclid_no', 'index 4', 'bulge-size-euclid', 'indices 10 to 14', 'asked after disk-edge-on-euclid_no', 'index 4', 'how-rounded-euclid', 'indices 15 to 17',' asked after smooth-or-featured-euclid_smooth', 'index 0', 'edge-on-bulge-euclid', 'indices 18 to 20', 'asked after disk-edge-on-euclid_yes', 'index 3', 'spiral-winding-euclid', 'indices 21 to 23', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'spiral-arm-count-euclid', 'indices 24 to 29', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'merging-euclid', 'indices 30 to 33', 'asked after None', 'clumps-euclid', 'indices 34 to 35', 'asked after disk-edge-on-euclid_no', 'index 4', 'problem-euclid', 'indices 36 to 38', 'asked after smooth-or-featured-euclid_problem', 'index 2', 'artifact-euclid', 'indices 39 to 45', 'asked after problem-euclid_artifact', 'index 37'], - 'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']} - } + 'euclid': euclid_ortho_schema } schema = schema_dict.get(args.schema, cosmic_dawn_ortho_schema) # setup the error reporting tool - https://app.honeybadger.io/projects/ @@ -105,15 +101,12 @@ else: logger = None + # load the model from checkpoint model = finetune.FinetuneableZoobotTree( - checkpoint_loc=args.checkpoint, # params specific to tree finetuning schema=schema, - # params for superclass i.e. any finetuning - encoder_dim=args.encoder_dim, - n_layers=args.n_layers, - prog_bar=args.progress_bar + zoobot_checkpoint_loc=args.checkpoint ) trainer = finetune.get_trainer( diff --git a/azure/batch/setup.py b/azure/batch/setup.py index 4ad1240..8976a8c 100644 --- a/azure/batch/setup.py +++ b/azure/batch/setup.py @@ -16,9 +16,9 @@ "Environment :: GPU :: NVIDIA CUDA" ], packages=setuptools.find_packages(), - python_requires=">=3.7", # tf 2.8.0 requires Python 3.7 and above + python_requires=">=3.9", # tf 2.8.0 requires Python 3.7 and above install_requires=[ - 'zoobot[pytorch_cu113] >= 1.0', # the big cheese - bring in the zoobot! + 'zoobot[pytorch-cu121] >= 2.0.0', # the big cheese - bring in the zoobot! 'requests >= 2.28.1', # used to download prediction images from a remote URL 'honeybadger' # used for error reporting ]