cosem_starter
mzouink committed Nov 7, 2024
1 parent dbf12fe commit bfd1f72
Showing 3 changed files with 216 additions and 66 deletions.
102 changes: 102 additions & 0 deletions docs/source/cosem_starter.rst
@@ -0,0 +1,102 @@

Fine-Tune Cosem Starter
============================

The CosemStarter in DaCapo allows you to load a pretrained COSEM model and fine-tune it for your experiments. This guide explains how to set up and use CosemStarter in DaCapo.

Prerequisites
-------------

Ensure that you have DaCapo installed and configured correctly.

Step 1: Import the CosemStartConfig
-----------------------------------

To get started, import ``CosemStartConfig`` from ``dacapo.experiments.starts``.

.. code-block:: python

    from dacapo.experiments.starts import CosemStartConfig

Step 2: Configure the Start Model
---------------------------------

``CosemStartConfig`` takes two parameters:

- **model_name**: The name of the model setup to load.
- **checkpoint**: The specific checkpoint ID to load the pretrained model from.

Example:

.. code-block:: python

    # Download a pretrained COSEM model and fine-tune from it.
    # The model is downloaded only the first time it is used.
    start_config = CosemStartConfig("setup04", "1820500")

This configuration downloads the COSEM model for setup ``setup04`` and loads checkpoint ``1820500``. The model is downloaded only once; subsequent runs reuse the downloaded copy.

Step 3: Create a Run with ``start_config``
------------------------------------------

To start from the pretrained model, pass ``start_config`` to your ``RunConfig``. The ``RunConfig`` initializes the run, which is then fine-tuned from the pretrained COSEM model.

Example:

.. code-block:: python

    from dacapo.experiments.runs import RunConfig

    run_config = RunConfig(
        # other parameters...
        start_config=start_config,
    )

Full Example
------------

Here’s how the complete setup looks:

.. code-block:: python

    from dacapo.experiments.starts import CosemStartConfig
    from dacapo.experiments.runs import RunConfig

    # Define the start configuration to load the pretrained COSEM model
    start_config = CosemStartConfig("setup04", "1820500")

    # Define the run configuration with the start configuration
    run_config = RunConfig(
        # other configurations,
        start_config=start_config,
    )

    # Now you can run this configuration in your experiment to start from the COSEM pretrained model

This setup starts your DaCapo run from the pretrained COSEM model so that you can fine-tune it as needed.

Available COSEM Pretrained Models
---------------------------------

The table below lists the available COSEM pretrained models and their details:

+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
| Setup     | Checkpoints             | Best Checkpoint | Classes                                      | Input Res | Output Res | Architecture   |
+===========+=========================+=================+==============================================+===========+============+================+
| setup04   | 975000, 625000, 1820500 | 1820500         | ecs, pm, mito, mito_mem, ves, ves_mem, endo, | 8 nm      | 4 nm       | Upsample U-Net |
|           |                         |                 | endo_mem, er, er_mem, eres, nuc, mt, mt_out  |           |            |                |
+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
| setup26.1 | 650000, 2580000         | 2580000         | mito, mito_mem, mito_ribo                    | 8 nm      | 4 nm       | Upsample U-Net |
+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
| setup28   | 775000                  | 775000          | er, er_mem                                   | 8 nm      | 4 nm       | Upsample U-Net |
+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
| setup36   | 500000, 1100000         | 1100000         | nuc, nucleo                                  | 8 nm      | 4 nm       | Upsample U-Net |
+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
| setup45   | 625000, 1634500         | 1634500         | ecs, pm                                      | 4 nm      | 4 nm       | U-Net          |
+-----------+-------------------------+-----------------+----------------------------------------------+-----------+------------+----------------+
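
When scripting several experiments, it can be convenient to keep the table above as a small lookup structure. The following is only a minimal sketch and is not part of DaCapo: ``COSEM_SETUPS`` and ``best_checkpoint`` are hypothetical names introduced here for illustration, and the values are copied from the table.

.. code-block:: python

    # Hypothetical helper, not part of DaCapo; values are copied from the table above.
    COSEM_SETUPS = {
        "setup04": {"checkpoints": ["975000", "625000", "1820500"], "best": "1820500"},
        "setup26.1": {"checkpoints": ["650000", "2580000"], "best": "2580000"},
        "setup28": {"checkpoints": ["775000"], "best": "775000"},
        "setup36": {"checkpoints": ["500000", "1100000"], "best": "1100000"},
        "setup45": {"checkpoints": ["625000", "1634500"], "best": "1634500"},
    }


    def best_checkpoint(setup_name):
        """Return the best checkpoint ID listed for a setup, as a string."""
        try:
            return COSEM_SETUPS[setup_name]["best"]
        except KeyError:
            known = ", ".join(sorted(COSEM_SETUPS))
            raise ValueError(
                f"Unknown setup {setup_name!r}; known setups: {known}"
            ) from None

The returned string could then be passed as the checkpoint argument, for example ``CosemStartConfig("setup04", best_checkpoint("setup04"))``.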

Notes
-----

- The model will download only the first time you use it. After that, it will reuse the downloaded version.
- Ensure that you have the necessary storage and access permissions configured for the COSEM model files.
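
The download-once behaviour described in the notes follows a common caching pattern, sketched below. This is not DaCapo's actual implementation, which this guide does not show: ``fetch_once`` and its ``download`` callback are hypothetical names used only to illustrate the idea.

.. code-block:: python

    from pathlib import Path


    def fetch_once(cache_dir, name, download):
        """Return the cached path for `name`, downloading only if it is missing.

        `download` is a caller-supplied function that writes the file to the
        given target path (hypothetical; stands in for the real download step).
        """
        target = Path(cache_dir) / name
        if not target.exists():
            # First use: create the cache directory and fetch the file.
            target.parent.mkdir(parents=True, exist_ok=True)
            download(target)
        # Later uses: the existing file is returned without downloading again.
        return target

A pattern like this needs write access to the cache directory on first use, which is why storage and permissions matter.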
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -13,6 +13,7 @@
tutorial
docker
aws
cosem_starter
autoapi/index
cli

179 changes: 113 additions & 66 deletions examples/starter_tutorial/minimal_tutorial.ipynb
@@ -172,23 +172,28 @@
"\n",
"# Create the zarr array with appropriate metadata\n",
"cell_array = prepare_ds(\n",
" \"cells3d.zarr\",\n",
" \"raw\",\n",
" Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n",
" \"cells3d.zarr/raw\",\n",
" cell_data.shape,\n",
" offset=offset,\n",
" voxel_size=voxel_size,\n",
" axis_names=axis_names,\n",
" units=units,\n",
" mode=\"w\",\n",
" dtype=np.uint8,\n",
" num_channels=None,\n",
")\n",
"\n",
"# Save the cell data to the zarr array\n",
"cell_array[cell_array.roi] = cell_data[1]\n",
"cell_array[cell_array.roi] = cell_data\n",
"\n",
"# Generate and save some pseudo ground truth data\n",
"mask_array = prepare_ds(\n",
" \"cells3d.zarr\",\n",
" \"mask\",\n",
" Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n",
" \"cells3d.zarr/mask\",\n",
" cell_data.shape[1:],\n",
" offset=offset,\n",
" voxel_size=voxel_size,\n",
" axis_names=axis_names[1:],\n",
" units=units,\n",
" mode=\"w\",\n",
" dtype=np.uint8,\n",
")\n",
"cell_mask = np.clip(gaussian(cell_data[1] / 255.0, sigma=1), 0, 255) * 255 > 30\n",
@@ -197,28 +202,36 @@
"\n",
"# Generate labels via connected components\n",
"labels_array = prepare_ds(\n",
" \"cells3d.zarr\",\n",
" \"labels\",\n",
" Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n",
" \"cells3d.zarr/labels\",\n",
" cell_data.shape[1:],\n",
" offset=offset,\n",
" voxel_size=voxel_size,\n",
" axis_names=axis_names[1:],\n",
" units=units,\n",
" mode=\"w\",\n",
" dtype=np.uint8,\n",
")\n",
"labels_array[labels_array.roi] = label(mask_array.to_ndarray(mask_array.roi))[0]\n",
"\n",
"print(\"Data saved to cells3d.zarr\")\n",
"import zarr\n",
"\n",
"print(zarr.open(\"cells3d.zarr\", mode=\"r\").tree())\n",
"# %% [markdown]\n",
"# Here we show a slice of the raw data:\n",
"# %%\n",
"# a custom label color map for showing instances\n",
"print(zarr.open(\"cells3d.zarr\").tree())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"\n",
"# Show the raw data\n",
"axes[0].imshow(cell_array.data[30])\n",
"axes[0].imshow(cell_array.data[0, 30])\n",
"axes[0].set_title(\"Raw Data\")\n",
"\n",
"# Show the labels using the custom label color map\n",
@@ -248,26 +261,59 @@
"metadata": {},
"outputs": [],
"source": [
"from dacapo.experiments.datasplits import DataSplitGenerator, DatasetSpec\n",
"\n",
"dataspecs = [\n",
" DatasetSpec(\n",
" dataset_type=type_crop,\n",
" raw_container=\"cells3d.zarr\",\n",
" raw_dataset=\"raw\",\n",
" gt_container=\"cells3d.zarr\",\n",
" gt_dataset=\"labels\",\n",
" )\n",
" for type_crop in [\"train\", \"val\"]\n",
"]\n",
"\n",
"datasplit_config = DataSplitGenerator(\n",
" name=\"skimage_tutorial_data\",\n",
" datasets=dataspecs,\n",
" input_resolution=voxel_size,\n",
" output_resolution=voxel_size,\n",
" targets=[\"cell\"],\n",
").compute()\n"
"from dacapo.experiments.datasplits import TrainValidateDataSplitConfig\n",
"from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig\n",
"from dacapo.experiments.datasplits.datasets.arrays import (\n",
" ZarrArrayConfig,\n",
" IntensitiesArrayConfig,\n",
")\n",
"from funlib.geometry import Coordinate\n",
"\n",
"datasplit_config = TrainValidateDataSplitConfig(\n",
" name=\"example_datasplit\",\n",
" train_configs=[\n",
" RawGTDatasetConfig(\n",
" name=\"example_dataset\",\n",
" raw_config=IntensitiesArrayConfig(\n",
" name=\"example_raw_normalized\",\n",
" source_array_config=ZarrArrayConfig(\n",
" name=\"example_raw\",\n",
" file_name=\"cells3d.zarr\",\n",
" dataset=\"raw\",\n",
" ),\n",
" min=0,\n",
" max=255,\n",
" ),\n",
" gt_config=ZarrArrayConfig(\n",
" name=\"example_gt\",\n",
" file_name=\"cells3d.zarr\",\n",
" dataset=\"mask\",\n",
" ),\n",
" )\n",
" ],\n",
" validate_configs=[\n",
" RawGTDatasetConfig(\n",
" name=\"example_dataset\",\n",
" raw_config=IntensitiesArrayConfig(\n",
" name=\"example_raw_normalized\",\n",
" source_array_config=ZarrArrayConfig(\n",
" name=\"example_raw\",\n",
" file_name=\"cells3d.zarr\",\n",
" dataset=\"raw\",\n",
" ),\n",
" min=0,\n",
" max=255,\n",
" ),\n",
" gt_config=ZarrArrayConfig(\n",
" name=\"example_gt\",\n",
" file_name=\"cells3d.zarr\",\n",
" dataset=\"labels\",\n",
" ),\n",
" )\n",
" ],\n",
")\n",
"datasplit = datasplit_config.datasplit_type(datasplit_config)\n",
"config_store.store_datasplit_config(datasplit_config)"
]
},
{
@@ -360,7 +406,7 @@
" name=\"example_unet\",\n",
" input_shape=(2, 132, 132),\n",
" eval_shape_increase=(8, 32, 32),\n",
" fmaps_in=1,\n",
" fmaps_in=2,\n",
" num_fmaps=8,\n",
" fmaps_out=8,\n",
" fmap_inc_factor=2,\n",
@@ -370,7 +416,7 @@
" constant_upsample=True,\n",
" padding=\"valid\",\n",
")\n",
"config_store.store_architecture_config(architecture_config)\n"
"config_store.store_architecture_config(architecture_config)"
]
},
{
@@ -397,7 +443,7 @@
" name=\"example\",\n",
" batch_size=10,\n",
" learning_rate=0.0001,\n",
" num_data_fetchers=8,\n",
" num_data_fetchers=1,\n",
" snapshot_interval=1000,\n",
" min_masked=0.05,\n",
" clip_raw=False,\n",
@@ -549,7 +595,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"import zarr\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
@@ -559,28 +604,31 @@
"\n",
"run_path = config_store.path.parent / run_config.name\n",
"\n",
"# BROWSER = False\n",
"num_snapshots = run_config.num_iterations // run_config.trainer_config.snapshot_interval\n",
"fig, ax = plt.subplots(num_snapshots, 3, figsize=(10, 2 * num_snapshots))\n",
"\n",
"# Set column titles\n",
"column_titles = [\"Raw\", \"Target\", \"Prediction\"]\n",
"for col in range(3):\n",
" ax[0, col].set_title(column_titles[col])\n",
"if num_snapshots > 0:\n",
" fig, ax = plt.subplots(num_snapshots, 3, figsize=(10, 2 * num_snapshots))\n",
"\n",
"for snapshot in range(num_snapshots):\n",
" snapshot_it = snapshot * run_config.trainer_config.snapshot_interval\n",
" # break\n",
" raw = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/raw\")[:]\n",
" target = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/target\")[0]\n",
" prediction = zarr.open(\n",
" f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/prediction\"\n",
" )[0]\n",
" c = (raw.shape[1] - target.shape[1]) // 2\n",
" ax[snapshot, 0].imshow(raw[raw.shape[0] // 2, c:-c, c:-c])\n",
" ax[snapshot, 1].imshow(target[target.shape[0] // 2])\n",
" ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2])\n",
" ax[snapshot, 0].set_ylabel(f\"Snapshot {snapshot_it}\")\n",
"plt.show()"
" # Set column titles\n",
" column_titles = [\"Raw\", \"Target\", \"Prediction\"]\n",
" for col in range(3):\n",
" ax[0, col].set_title(column_titles[col])\n",
"\n",
" for snapshot in range(num_snapshots):\n",
" snapshot_it = snapshot * run_config.trainer_config.snapshot_interval\n",
" # break\n",
" raw = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/raw\")[:]\n",
" target = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/target\")[0]\n",
" prediction = zarr.open(\n",
" f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/prediction\"\n",
" )[0]\n",
" c = (raw.shape[2] - target.shape[1]) // 2\n",
" ax[snapshot, 0].imshow(raw[1, raw.shape[0] // 2, c:-c, c:-c])\n",
" ax[snapshot, 1].imshow(target[target.shape[0] // 2])\n",
" ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2])\n",
" ax[snapshot, 0].set_ylabel(f\"Snapshot {snapshot_it}\")\n",
" plt.show()"
]
},
{
@@ -589,7 +637,6 @@
"metadata": {},
"outputs": [],
"source": [
"# Visualize validations\n",
"import zarr\n",
"\n",
"num_validations = run_config.num_iterations // run_config.validation_interval\n",
@@ -604,16 +651,16 @@
" dataset = run.datasplit.validate[0].name\n",
" validation_it = validation * run_config.validation_interval\n",
" # break\n",
" raw = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/raw\")[:]\n",
" gt = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/gt\")[0]\n",
" raw = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/raw\")\n",
" gt = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/gt\")\n",
" pred_path = f\"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/prediction\"\n",
" out_path = f\"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))\"\n",
" output = zarr.open(out_path)[:]\n",
" prediction = zarr.open(pred_path)[0]\n",
" c = (raw.shape[1] - gt.shape[1]) // 2\n",
" c = (raw.shape[2] - gt.shape[1]) // 2\n",
" if c != 0:\n",
" raw = raw[:, c:-c, c:-c]\n",
" ax[validation - 1, 0].imshow(raw[raw.shape[0] // 2])\n",
" raw = raw[:, :, c:-c, c:-c]\n",
" ax[validation - 1, 0].imshow(raw[1, raw.shape[1] // 2])\n",
" ax[validation - 1, 1].imshow(\n",
" gt[gt.shape[0] // 2], cmap=label_cmap, interpolation=\"none\"\n",
" )\n",
