diff --git a/docs/source/cosem_starter.rst b/docs/source/cosem_starter.rst new file mode 100644 index 000000000..e7d8da980 --- /dev/null +++ b/docs/source/cosem_starter.rst @@ -0,0 +1,102 @@ + +Fine-Tune Cosem Starter +============================ + +The CosemStarter in DaCapo allows you to load a pretrained COSEM model and fine-tune it for your experiments. This guide explains how to set up and use CosemStarter in DaCapo. + +Prerequisites +------------- + +Ensure that you have DaCapo installed and configured correctly. + +Step 1: Import the CosemStartConfig +----------------------------------- + +To get started, you need to import `CosemStartConfig` from `dacapo.experiments.starts`. + +.. code-block:: python + + from dacapo.experiments.starts import CosemStartConfig + +Step 2: Configure the Start Model +--------------------------------- + +The `CosemStartConfig` takes two parameters: + +- **model_name**: The name of the model setup to load. +- **checkpoint**: The specific checkpoint ID to load the pretrained model from. + +Example: + +.. code-block:: python + + # We will now download a pretrained COSEM model and fine-tune from that model. + # It will only download the model the first time it is used. + + start_config = CosemStartConfig("setup04", "1820500") + +This configuration will download the COSEM model from setup `setup04` and load the checkpoint `1820500`. You only need to download the model once; subsequent runs will use the downloaded model. + +Step 3: Create a Run with `start_config` +---------------------------------------- + +To start from the pretrained model, add `start_config` to your `RunConfig`. The `RunConfig` initializes the run and allows fine-tuning from the pretrained COSEM model. + +Example: + +.. code-block:: python + + from dacapo.experiments.runs import RunConfig + + run_config = RunConfig( + # other parameters... + start_config=start_config, + ) + +Full Example +------------ + +Here’s how the complete setup looks: + +.. code-block:: python + + from dacapo.experiments.starts import CosemStartConfig + from dacapo.experiments.runs import RunConfig + + # Define the start configuration to load the pretrained COSEM model + start_config = CosemStartConfig("setup04", "1820500") + + # Define the run configuration with the start configuration + run_config = RunConfig( + # other configurations, + start_config=start_config, + ) + + # Now you can run this configuration in your experiment to start from the COSEM pretrained model + +This setup will initiate your DaCapo run from the pretrained COSEM model and allow you to fine-tune it as needed. + +Available COSEM Pretrained Models +--------------------------------- + +Below is a table of the COSEM pretrained models available, along with their details: + ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ +| Model | Checkpoints | Best Checkpoint| Classes | Input Res | Output Res | Model | ++===========+============================+=================+==============================================================+===========+============+=================+ +| setup04 | 975000, 625000, 1820500 | 1820500 | ecs, pm, mito, mito_mem, ves, ves_mem, endo, endo_mem, er, er_mem, eres, nuc, mt, mt_out | 8 nm | 4 nm | Upsample U-Net | ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ +| setup26.1 | 650000, 2580000 | 2580000 | mito, mito_mem, mito_ribo | 8 nm | 4 nm | Upsample U-Net | ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ +| setup28 | 775000 | 775000 | er, er_mem | 8 nm | 4 nm | Upsample U-Net | ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ +| setup36 | 500000, 1100000 | 1100000 | nuc, nucleo | 8 nm | 4 nm | Upsample U-Net | ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ +| setup45 | 625000, 1634500 | 1634500 | ecs, pm | 4 nm | 4 nm | U-Net | ++-----------+----------------------------+-----------------+--------------------------------------------------------------+-----------+------------+-----------------+ + +Notes +----- + +- The model will download only the first time you use it. After that, it will reuse the downloaded version. +- Ensure that you have the necessary storage and access permissions configured for the COSEM model files. diff --git a/docs/source/index.rst b/docs/source/index.rst index e08339a02..3a94253db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,6 +13,7 @@ tutorial docker aws + cosem_starter autoapi/index cli diff --git a/examples/starter_tutorial/minimal_tutorial.ipynb b/examples/starter_tutorial/minimal_tutorial.ipynb index 2b9f2202a..3fbcc0d4e 100644 --- a/examples/starter_tutorial/minimal_tutorial.ipynb +++ b/examples/starter_tutorial/minimal_tutorial.ipynb @@ -172,23 +172,28 @@ "\n", "# Create the zarr array with appropriate metadata\n", "cell_array = prepare_ds(\n", - " \"cells3d.zarr\",\n", - " \"raw\",\n", - " Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n", + " \"cells3d.zarr/raw\",\n", + " cell_data.shape,\n", + " offset=offset,\n", " voxel_size=voxel_size,\n", + " axis_names=axis_names,\n", + " units=units,\n", + " mode=\"w\",\n", " dtype=np.uint8,\n", - " num_channels=None,\n", ")\n", "\n", "# Save the cell data to the zarr array\n", - "cell_array[cell_array.roi] = cell_data[1]\n", + "cell_array[cell_array.roi] = cell_data\n", "\n", "# Generate and save some pseudo ground truth data\n", "mask_array = prepare_ds(\n", - " \"cells3d.zarr\",\n", - " \"mask\",\n", - " Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n", + " \"cells3d.zarr/mask\",\n", + " cell_data.shape[1:],\n", + " offset=offset,\n", " voxel_size=voxel_size,\n", + " axis_names=axis_names[1:],\n", + " units=units,\n", + " mode=\"w\",\n", " dtype=np.uint8,\n", ")\n", "cell_mask = np.clip(gaussian(cell_data[1] / 255.0, sigma=1), 0, 255) * 255 > 30\n", @@ -197,10 +202,13 @@ "\n", "# Generate labels via connected components\n", "labels_array = prepare_ds(\n", - " \"cells3d.zarr\",\n", - " \"labels\",\n", - " Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,\n", + " \"cells3d.zarr/labels\",\n", + " cell_data.shape[1:],\n", + " offset=offset,\n", " voxel_size=voxel_size,\n", + " axis_names=axis_names[1:],\n", + " units=units,\n", + " mode=\"w\",\n", " dtype=np.uint8,\n", ")\n", "labels_array[labels_array.roi] = label(mask_array.to_ndarray(mask_array.roi))[0]\n", @@ -208,17 +216,22 @@ "print(\"Data saved to cells3d.zarr\")\n", "import zarr\n", "\n", - "print(zarr.open(\"cells3d.zarr\", mode=\"r\").tree())\n", - "# %% [markdown]\n", - "# Here we show a slice of the raw data:\n", - "# %%\n", - "# a custom label color map for showing instances\n", + "print(zarr.open(\"cells3d.zarr\").tree())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "import matplotlib.pyplot as plt\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", "# Show the raw data\n", - "axes[0].imshow(cell_array.data[30])\n", + "axes[0].imshow(cell_array.data[0, 30])\n", "axes[0].set_title(\"Raw Data\")\n", "\n", "# Show the labels using the custom label color map\n", @@ -248,26 +261,59 @@ "metadata": {}, "outputs": [], "source": [ - "from dacapo.experiments.datasplits import DataSplitGenerator, DatasetSpec\n", - "\n", - "dataspecs = [\n", - " DatasetSpec(\n", - " dataset_type=type_crop,\n", - " raw_container=\"cells3d.zarr\",\n", - " raw_dataset=\"raw\",\n", - " gt_container=\"cells3d.zarr\",\n", - " gt_dataset=\"labels\",\n", - " )\n", - " for type_crop in [\"train\", \"val\"]\n", - "]\n", - "\n", - "datasplit_config = DataSplitGenerator(\n", - " name=\"skimage_tutorial_data\",\n", - " datasets=dataspecs,\n", - " input_resolution=voxel_size,\n", - " output_resolution=voxel_size,\n", - " targets=[\"cell\"],\n", - ").compute()\n" + "from dacapo.experiments.datasplits import TrainValidateDataSplitConfig\n", + "from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig\n", + "from dacapo.experiments.datasplits.datasets.arrays import (\n", + " ZarrArrayConfig,\n", + " IntensitiesArrayConfig,\n", + ")\n", + "from funlib.geometry import Coordinate\n", + "\n", + "datasplit_config = TrainValidateDataSplitConfig(\n", + " name=\"example_datasplit\",\n", + " train_configs=[\n", + " RawGTDatasetConfig(\n", + " name=\"example_dataset\",\n", + " raw_config=IntensitiesArrayConfig(\n", + " name=\"example_raw_normalized\",\n", + " source_array_config=ZarrArrayConfig(\n", + " name=\"example_raw\",\n", + " file_name=\"cells3d.zarr\",\n", + " dataset=\"raw\",\n", + " ),\n", + " min=0,\n", + " max=255,\n", + " ),\n", + " gt_config=ZarrArrayConfig(\n", + " name=\"example_gt\",\n", + " file_name=\"cells3d.zarr\",\n", + " dataset=\"mask\",\n", + " ),\n", + " )\n", + " ],\n", + " validate_configs=[\n", + " RawGTDatasetConfig(\n", + " name=\"example_dataset\",\n", + " raw_config=IntensitiesArrayConfig(\n", + " name=\"example_raw_normalized\",\n", + " source_array_config=ZarrArrayConfig(\n", + " name=\"example_raw\",\n", + " file_name=\"cells3d.zarr\",\n", + " dataset=\"raw\",\n", + " ),\n", + " min=0,\n", + " max=255,\n", + " ),\n", + " gt_config=ZarrArrayConfig(\n", + " name=\"example_gt\",\n", + " file_name=\"cells3d.zarr\",\n", + " dataset=\"labels\",\n", + " ),\n", + " )\n", + " ],\n", + ")\n", + "datasplit = datasplit_config.datasplit_type(datasplit_config)\n", + "config_store.store_datasplit_config(datasplit_config)" ] }, { @@ -360,7 +406,7 @@ " name=\"example_unet\",\n", " input_shape=(2, 132, 132),\n", " eval_shape_increase=(8, 32, 32),\n", - " fmaps_in=1,\n", + " fmaps_in=2,\n", " num_fmaps=8,\n", " fmaps_out=8,\n", " fmap_inc_factor=2,\n", @@ -370,7 +416,7 @@ " constant_upsample=True,\n", " padding=\"valid\",\n", ")\n", - "config_store.store_architecture_config(architecture_config)\n" + "config_store.store_architecture_config(architecture_config)" ] }, { @@ -397,7 +443,7 @@ " name=\"example\",\n", " batch_size=10,\n", " learning_rate=0.0001,\n", - " num_data_fetchers=8,\n", + " num_data_fetchers=1,\n", " snapshot_interval=1000,\n", " min_masked=0.05,\n", " clip_raw=False,\n", @@ -549,7 +595,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "import zarr\n", "from matplotlib.colors import ListedColormap\n", "\n", @@ -559,28 +604,31 @@ "\n", "run_path = config_store.path.parent / run_config.name\n", "\n", + "# BROWSER = False\n", "num_snapshots = run_config.num_iterations // run_config.trainer_config.snapshot_interval\n", - "fig, ax = plt.subplots(num_snapshots, 3, figsize=(10, 2 * num_snapshots))\n", "\n", - "# Set column titles\n", - "column_titles = [\"Raw\", \"Target\", \"Prediction\"]\n", - "for col in range(3):\n", - " ax[0, col].set_title(column_titles[col])\n", + "if num_snapshots > 0:\n", + " fig, ax = plt.subplots(num_snapshots, 3, figsize=(10, 2 * num_snapshots))\n", "\n", - "for snapshot in range(num_snapshots):\n", - " snapshot_it = snapshot * run_config.trainer_config.snapshot_interval\n", - " # break\n", - " raw = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/raw\")[:]\n", - " target = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/target\")[0]\n", - " prediction = zarr.open(\n", - " f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/prediction\"\n", - " )[0]\n", - " c = (raw.shape[1] - target.shape[1]) // 2\n", - " ax[snapshot, 0].imshow(raw[raw.shape[0] // 2, c:-c, c:-c])\n", - " ax[snapshot, 1].imshow(target[target.shape[0] // 2])\n", - " ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2])\n", - " ax[snapshot, 0].set_ylabel(f\"Snapshot {snapshot_it}\")\n", - "plt.show()" + " # Set column titles\n", + " column_titles = [\"Raw\", \"Target\", \"Prediction\"]\n", + " for col in range(3):\n", + " ax[0, col].set_title(column_titles[col])\n", + "\n", + " for snapshot in range(num_snapshots):\n", + " snapshot_it = snapshot * run_config.trainer_config.snapshot_interval\n", + " # break\n", + " raw = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/raw\")[:]\n", + " target = zarr.open(f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/target\")[0]\n", + " prediction = zarr.open(\n", + " f\"{run_path}/snapshot.zarr/{snapshot_it}/volumes/prediction\"\n", + " )[0]\n", + " c = (raw.shape[2] - target.shape[1]) // 2\n", + " ax[snapshot, 0].imshow(raw[1, raw.shape[0] // 2, c:-c, c:-c])\n", + " ax[snapshot, 1].imshow(target[target.shape[0] // 2])\n", + " ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2])\n", + " ax[snapshot, 0].set_ylabel(f\"Snapshot {snapshot_it}\")\n", + " plt.show()" ] }, { @@ -589,7 +637,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Visualize validations\n", "import zarr\n", "\n", "num_validations = run_config.num_iterations // run_config.validation_interval\n", @@ -604,16 +651,16 @@ " dataset = run.datasplit.validate[0].name\n", " validation_it = validation * run_config.validation_interval\n", " # break\n", - " raw = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/raw\")[:]\n", - " gt = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/gt\")[0]\n", + " raw = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/raw\")\n", + " gt = zarr.open(f\"{run_path}/validation.zarr/inputs/{dataset}/gt\")\n", " pred_path = f\"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/prediction\"\n", " out_path = f\"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))\"\n", " output = zarr.open(out_path)[:]\n", " prediction = zarr.open(pred_path)[0]\n", - " c = (raw.shape[1] - gt.shape[1]) // 2\n", + " c = (raw.shape[2] - gt.shape[1]) // 2\n", " if c != 0:\n", - " raw = raw[:, c:-c, c:-c]\n", - " ax[validation - 1, 0].imshow(raw[raw.shape[0] // 2])\n", + " raw = raw[:, :, c:-c, c:-c]\n", + " ax[validation - 1, 0].imshow(raw[1, raw.shape[1] // 2])\n", " ax[validation - 1, 1].imshow(\n", " gt[gt.shape[0] // 2], cmap=label_cmap, interpolation=\"none\"\n", " )\n",