Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade zoobot to latest version #56

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions azure/batch/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
FROM nvidia/cuda:11.3.1-base-ubuntu20.04
FROM nvidia/cuda:12.1.1-base-ubuntu22.04

ENV DEBIAN_FRONTEND noninteractive
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /usr/src/zoobot

# Install prerequisites and add deadsnakes PPA for Python 3.9
RUN apt-get update && apt-get -y upgrade && \
apt-get install --no-install-recommends -y \
software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install --no-install-recommends -y \
build-essential \
python3 \
python3.9 \
python3.9-distutils \
python3.9-dev \
python3-pip \
git && \
apt-get clean && rm -rf /var/lib/apt/lists/*

RUN ln -s /usr/bin/python3 /usr/bin/python
# Link Python 3.9 as default
RUN ln -sf /usr/bin/python3.9 /usr/bin/python && \
ln -sf /usr/bin/python3.9 /usr/bin/python3

# install a newer version of pip
# as we can't use the use the ubuntu package pip version (20.0.2)
Expand All @@ -23,6 +31,7 @@ RUN apt-get remove -y python3-pip
RUN ln -s /usr/local/bin/pip3 /usr/bin/pip
RUN ln -s /usr/local/bin/pip3 /usr/bin/pip3

# install our dependencies (see setup.py)

# Install project dependencies (see setup.py)
COPY setup.py .
RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu121
12 changes: 2 additions & 10 deletions azure/batch/scripts/predict_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@ def __getitem__(self, idx):
# ensure we raise other response errors like 404 and 500 etc
# Note: we don't retry on errors that aren't in the `status_forcelist`, instead we fast fail!
response.raise_for_status()
url_mime_type = response.headers['content-type']
# handle PNG images
if url_mime_type == 'image/png':
# use PIL image to read the png file buffer
image = Image.open(response.raw)
else: # but assume all other images are JPEG
# HWC PIL image
image = Image.fromarray(
galaxy_dataset.decode_jpeg(response.raw.read()))
image = Image.open(response.raw)
except Exception as e:
# add some logging on the failed url
logging.critical('Cannot load {}'.format(url))
Expand All @@ -98,7 +90,7 @@ def __getitem__(self, idx):
class PredictionGalaxyDataModule(galaxy_datamodule.GalaxyDataModule):
# override the setup method to setup our prediction dataset on the prediction catalog
def setup(self, stage: Optional[str] = None):
self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.transform)
self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.test_transform)


def save_predictions_to_json(predictions: np.ndarray, image_ids: List[str], label_cols: List[str], save_loc: str):
Expand Down
15 changes: 4 additions & 11 deletions azure/batch/scripts/train_model_finetune_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

from zoobot.pytorch.training import finetune
from zoobot.shared.schemas import cosmic_dawn_ortho_schema
from zoobot.shared.schemas import cosmic_dawn_ortho_schema, euclid_ortho_schema

if __name__ == '__main__':

Expand Down Expand Up @@ -41,11 +41,7 @@

schema_dict = {
'cosmic_dawn': cosmic_dawn_ortho_schema,
'euclid': {
'label_cols': ['smooth-or-featured-euclid_smooth', 'smooth-or-featured-euclid_featured-or-disk', 'smooth-or-featured-euclid_problem', 'disk-edge-on-euclid_yes', 'disk-edge-on-euclid_no', 'has-spiral-arms-euclid_yes', 'has-spiral-arms-euclid_no', 'bar-euclid_strong', 'bar-euclid_weak', 'bar-euclid_no', 'bulge-size-euclid_dominant', 'bulge-size-euclid_large', 'bulge-size-euclid_moderate', 'bulge-size-euclid_small', 'bulge-size-euclid_none', 'how-rounded-euclid_round', 'how-rounded-euclid_in-between', 'how-rounded-euclid_cigar-shaped', 'edge-on-bulge-euclid_boxy', 'edge-on-bulge-euclid_none', 'edge-on-bulge-euclid_rounded', 'spiral-winding-euclid_tight', 'spiral-winding-euclid_medium', 'spiral-winding-euclid_loose', 'spiral-arm-count-euclid_1', 'spiral-arm-count-euclid_2', 'spiral-arm-count-euclid_3', 'spiral-arm-count-euclid_4', 'spiral-arm-count-euclid_more-than-4', 'spiral-arm-count-euclid_cant-tell', 'merging-euclid_none', 'merging-euclid_minor-disturbance', 'merging-euclid_major-disturbance', 'merging-euclid_merger', 'clumps-euclid_yes', 'clumps-euclid_no', 'problem-euclid_star', 'problem-euclid_artifact', 'problem-euclid_zoom', 'artifact-euclid_satellite', 'artifact-euclid_scattered', 'artifact-euclid_diffraction', 'artifact-euclid_ray', 'artifact-euclid_saturation', 'artifact-euclid_other', 'artifact-euclid_ghost'],
'questions': ['smooth-or-featured-euclid', 'indices 0 to 2', 'asked after None', 'disk-edge-on-euclid', 'indices 3 to 4', 'asked after smooth-or-featured-euclid_featured-or-disk', 'index 1', 'has-spiral-arms-euclid', 'indices 5 to 6', 'asked after disk-edge-on-euclid_no', 'index 4', 'bar-euclid', 'indices 7 to 9', 'asked after disk-edge-on-euclid_no', 'index 4', 'bulge-size-euclid', 'indices 10 to 14', 'asked after disk-edge-on-euclid_no', 'index 4', 'how-rounded-euclid', 'indices 15 to 17',' asked after smooth-or-featured-euclid_smooth', 'index 0', 'edge-on-bulge-euclid', 'indices 18 to 20', 'asked after disk-edge-on-euclid_yes', 'index 3', 'spiral-winding-euclid', 'indices 21 to 23', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'spiral-arm-count-euclid', 'indices 24 to 29', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'merging-euclid', 'indices 30 to 33', 'asked after None', 'clumps-euclid', 'indices 34 to 35', 'asked after disk-edge-on-euclid_no', 'index 4', 'problem-euclid', 'indices 36 to 38', 'asked after smooth-or-featured-euclid_problem', 'index 2', 'artifact-euclid', 'indices 39 to 45', 'asked after problem-euclid_artifact', 'index 37'],
'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']}
}
'euclid': euclid_ortho_schema
}
schema = schema_dict.get(args.schema, cosmic_dawn_ortho_schema)
# setup the error reporting tool - https://app.honeybadger.io/projects/
Expand Down Expand Up @@ -105,15 +101,12 @@
else:
logger = None


# load the model from checkpoint
model = finetune.FinetuneableZoobotTree(
checkpoint_loc=args.checkpoint,
# params specific to tree finetuning
schema=schema,
# params for superclass i.e. any finetuning
encoder_dim=args.encoder_dim,
n_layers=args.n_layers,
prog_bar=args.progress_bar
zoobot_checkpoint_loc=args.checkpoint
)

trainer = finetune.get_trainer(
Expand Down
4 changes: 2 additions & 2 deletions azure/batch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"Environment :: GPU :: NVIDIA CUDA"
],
packages=setuptools.find_packages(),
python_requires=">=3.7", # tf 2.8.0 requires Python 3.7 and above
python_requires=">=3.9", # tf 2.8.0 requires Python 3.7 and above
install_requires=[
'zoobot[pytorch_cu113] >= 1.0', # the big cheese - bring in the zoobot!
'zoobot[pytorch-cu121] >= 2.0.0', # the big cheese - bring in the zoobot!
'requests >= 2.28.1', # used to download prediction images from a remote URL
'honeybadger' # used for error reporting
]
Expand Down
Loading