Skip to content

Commit

Permalink
Merge pull request #136 from chromatix-team/dev
Browse files Browse the repository at this point in the history
Release Chromatix 0.3.0
  • Loading branch information
diptodip authored Jun 10, 2024
2 parents 7304cd3 + e0e9de9 commit 432772b
Show file tree
Hide file tree
Showing 68 changed files with 2,832 additions and 431 deletions.
12 changes: 0 additions & 12 deletions .github/workflows/black.yaml

This file was deleted.

22 changes: 22 additions & 0 deletions .github/workflows/format_lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Ruff
on: [push, pull_request]

jobs:
build:

runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.4.8
- name: Lint with ruff
run: ruff format .
- name: Format with ruff
run: ruff --output-format=github --fix .
12 changes: 4 additions & 8 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Python package
name: Tests

on: [push]

Expand All @@ -8,23 +8,19 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest ruff
pip install pytest
pip install .
- name: Test with ruff
run: |
ruff src/
- name: Test with pytest
run: |
pytest tests/
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<img alt="Chromatix logo" src="https://github.com/chromatix-team/chromatix/blob/main/docs/media/logo_text_black.png?raw=true">
</picture>

![CI](https://github.com/chromatix-team/chromatix/actions/workflows/test.yaml/badge.svg) ![Black](https://github.com/chromatix-team/chromatix/actions/workflows/black.yaml/badge.svg)
![CI](https://github.com/chromatix-team/chromatix/actions/workflows/test.yaml/badge.svg) ![Ruff](https://github.com/chromatix-team/chromatix/actions/workflows/format_lint.yaml/badge.svg)

[**Installation**](#installation)
| [**Usage**](#usage)
Expand Down Expand Up @@ -90,7 +90,7 @@ Chromatix was started by Diptodip Deb ([@diptodip](https://www.github.com/diptod

* Amey Chaware ([@isildur7](https://www.github.com/isildur7))
* Amit Kohli ([@apsk14](https://www.github.com/apsk14))
* Cédric Allier
* Cédric Allier ([@allierc](https://github.com/allierc))
* Changjia Cai ([@caichangjia](https://github.com/caichangjia))
* Geneva Schlafly ([@gschlafly](https://github.com/gschlafly))
* Guanghan Meng ([@guanghanmeng](https://github.com/guanghanmeng))
Expand Down
4 changes: 2 additions & 2 deletions docs/101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
"import chromatix.functional as cx\n",
"from chromatix import Field, ScalarField\n",
"from chromatix.systems import OpticalSystem, Microscope, Optical4FSystemPSF\n",
"from chromatix.elements import PlaneWave, FFLens, ThinSample, BasicSensor, ZernikeAberrations\n",
"from chromatix.utils import siemens_star, trainable"
"from chromatix.elements import PlaneWave, FFLens, ThinSample, BasicSensor, ZernikeAberrations, trainable\n",
"from chromatix.utils import siemens_star"
]
},
{
Expand Down
16 changes: 12 additions & 4 deletions docs/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,22 @@ Chromatix tries to respect composable `jax` transformations, so you can use all
We discuss these styles of parallelism in our documentation on [Parallelism](parallelism.md).

## How do I decide which parameters get optimized?
Any attribute of an element that is specified as a possibly trainable parameter can be initialized using `chromatix.utils.trainable` in order to make it trainable. Otherwise, the attribute will be initialized (using either an `Array`, `float`, or `Callable` that takes a shape argument as specified in the documentation for that function) as non-trainable state of that element. If you are initializing an attribute as trainable using an initialization function, then you can specify whether that function requires a `jax.random.PRNGKey` or not. For example, if you are initializing the pixels of a phase mask with the `flat_phase` function, then you can use `trainable(flat_phase, rng=False)` because `flat_phase` takes only a shape argument.
Any attribute of a Chromatix element that is specified as a possibly trainable parameter can be initialized using `chromatix.utils.trainable` in order to make it trainable. Otherwise, the attribute will be initialized (using either an `Array`, `float`, or `Callable` that takes a shape argument as specified in the documentation for that function) as non-trainable state of that element. If you are initializing an attribute as trainable using an initialization function, then you can specify whether that function requires a `jax.random.PRNGKey` or not. For example, if you are initializing the pixels of a phase mask with the `flat_phase` function, then you can use `trainable(flat_phase, rng=False)` because `flat_phase` takes only a shape argument.

For example:
!!! warning
In order to use `trainable`, you must be using a Chromatix element as shown
below. This function will not work if you are making a custom `nn.Module`
using Flax and want something to always be a trainable parameter. In
that case, you must use `self.param` as shown in the [Flax documentation]
(https://flax.readthedocs.io/en/latest/api_reference/flax.linen/
module.html#flax.linen.Module.param).

Here's an example of how to use `trainable`:

```python
import jax
from chromatix.elements import ThinLens, PhaseMask
from chromatix.utils import trainable, flat_phase
from chromatix.elements import ThinLens, PhaseMask, trainable
from chromatix.utils import flat_phase

# Refractive index is trainable and initialized to 1.33
# Focal distance and NA are not trainable
Expand Down
322 changes: 322 additions & 0 deletions docs/examples/bandlimited_angular_spectrum.ipynb

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions docs/examples/filaments.ipynb

Large diffs are not rendered by default.

326 changes: 326 additions & 0 deletions docs/examples/gabor_hologram.ipynb

Large diffs are not rendered by default.

150 changes: 150 additions & 0 deletions docs/examples/pollen.ipynb

Large diffs are not rendered by default.

217 changes: 217 additions & 0 deletions docs/examples/sas.ipynb

Large diffs are not rendered by default.

96 changes: 40 additions & 56 deletions docs/examples/seidel_fitting.ipynb

Large diffs are not rendered by default.

85 changes: 43 additions & 42 deletions docs/examples/zernike_fitting.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ Welcome to `chromatix`, a differentiable wave optics library built using `jax` w

Here are some of the cool things we've already built with `chromatix`:

- [**Holoscope**](docs/examples/holoscope.ipynb): PSF engineering to optimally encode a 3D volume into a 2D image.
- [**Computer Generated Holography**](docs/examples/cgh.ipynb): optimizing a phase mask to produce a 3D hologram.
- [**Aberration Phase Retrieval**](docs/examples/zernike_fitting.ipynb): fitting Zernike coefficients to a measured aberrated PSF.
- [**Holoscope**](examples/holoscope.ipynb): PSF engineering to optimally encode a 3D volume into a 2D image.
- [**Computer Generated Holography**](examples/cgh.ipynb): optimizing a phase mask to produce a 3D hologram.
- [**Aberration Phase Retrieval**](examples/zernike_fitting.ipynb): fitting Zernike coefficients to a measured aberrated PSF.

Chromatix describes optical systems as sequences of sources and optical elements, composed in a similar style as neural network layers. These elements pass `Field` objects to each other, which contain both the tensor representation of the field at particular planes as well as information about the spatial sampling of the field and its spectrum. Typically, a user will not have to construct or deal with these `Field` objects unless they want to, but they are how `chromatix` can keep track of a lot of details of a simulation under the hood. Here's a very brief example of using `chromatix` to calculate the intensity of a widefield PSF (point spread function) at a single wavelength by describing a 4f system with a flat phase mask:

Expand Down
59 changes: 35 additions & 24 deletions docs/installing.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,41 @@

## System Requirements

Chromatix is based on [`jax`](https://github.com/google/jax) which only works
on macOS and Linux, not Windows.
Chromatix is based on [`jax`](https://github.com/google/jax) which can be
installed on macOS, Linux (Ubuntu), and Windows.

If you would like to run simulations on GPU, you will need an NVIDIA GPU with
CUDA support.
CUDA support. Ubuntu installations can take advantage of NVIDIA GPUs assuming
a recent NVIDIA driver has been installed. Windows installations can take
advantage of NVIDIA GPUs by installing through WSL2 (Ubuntu) on an up to date
Windows 10+ installation with a recent NVIDIA driver installed on Windows.
Installations on macOS are CPU only on both Intel and Apple Silicon, with very
limited GPU support.

!!! warning
Installing `jax` automatically through dependencies in `pyproject.toml`
can have some issues, as the CUDA version for your environment won't be
automatically detected. We recommend installing `jax` first as described in
the [`jax` README](https://github.com/google/jax#pip-installation-gpu-cuda)
Installing `jax` automatically through dependencies in `pyproject.toml` can
have some issues, e.g. a CUDA capable version of `jax` will not be installed
by default. We recommend installing `jax` first as described in the [`jax` README](https://github.com/google/jax?tab=readme-ov-file#installation)
in order to make sure that you install the version with appropriate CUDA
support for running on GPUs, if desired. Also see our section on installing
with `conda` below if you wouuld like to avoid installing your own CUDA
and/or building `jax` from source.
support for running on GPUs, if desired.

## Using `pip`

Chromatix can be installed on any supported operating system with Python 3.10+.
First install `jax` as described in the [`jax` README](https://github.com/google/jax?tab=readme-ov-file#installation).
NVIDIA support will be automatically installed if you install with `pip install jax["cuda12"]`.
Note that `jax` currently only supports CUDA 12. If your NVIDIA driver is compatible with CUDA 12
but is older than the version that the default `jax` installation is built for using `pip`, you
may see a warning when running your code that `jax` has disabled parallel compilation. This is
not an error and your code should still use the GPU, but it may take longer to compile before running.

!!! info
If you are on Windows 10+ and want NVIDIA GPU support, first make sure
you have an [up to date driver installed](https://www.nvidia.com/download/index.aspx)
for Windows. Then, [install WSL2](https://learn.microsoft.com/en-us/windows/wsl/install)
so that you have a terminal with Ubuntu running in WSL2. If you now install `jax`
using the instructions above, you should automatically get GPU support.

Once you have installed `jax`, you can install `chromatix` using:
```bash
$ pip install git+https://github.com/chromatix-team/chromatix.git
Expand All @@ -30,23 +47,17 @@ $ git clone https://github.com/chromatix-team/chromatix
$ cd chromatix
$ pip install -e .
```
Editable installations for development are recommended if you would like to
make changes to the internals of Chromatix or add new features (pull requests
welcomed!). Otherwise, please use the first installation command to get the
latest version of Chromatix.

Another option for development is to use a Python project management tool such
as [`Hatch`](https://hatch.pypa.io/latest/).

## Using `conda`

We do not package `chromatix` for `conda` because `jax` is also not officially
packaged for `conda`. However, if you would like to install `chromatix` into a `conda` environment
and also use a `conda` installation of CUDA, you can use the following instructions:

After creating and activating a `conda` environment with a supported Python version:
```bash
$ conda install -c conda-forge cudatoolkit=11.X
$ conda install -c conda-forge cudnn=A.B
$ conda install -c nvidia cuda-nvcc
$ pip install --upgrade "jax[cuda11_cudnnAB]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install git+https://github.com/chromatix-team/chromatix.git
```
You will have to replace `X` above with the appropriate version supported by your graphics driver (e.g. `11.4`), and you must ensure
that `A` and `B` are the same for both the installation of `cudnn` and in the options when installing `jax` (e.g. `8.2` and `82`). You can see the versions of `cudnn`
for which `jax` has been packaged at: [https://storage.googleapis.com/jax-releases/jax_cuda_releases.html](https://storage.googleapis.com/jax-releases/jax_cuda_releases.html).
packaged for `conda`. However, if you would like to install `chromatix` into a
`conda` environment, you can [first create and activate a `conda` environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#)
with a supported Python version (3.10+), and then follow the `pip` installation instructions above.
11 changes: 6 additions & 5 deletions examples/parallel_imaging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from chromatix.systems import Microscope, Optical4FSystemPSF
from chromatix.elements import BasicSensor, trainable
from chromatix.utils import flat_phase
from functools import partial
from time import perf_counter_ns

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from time import perf_counter_ns
from chromatix.elements import BasicSensor, trainable
from chromatix.systems import Microscope, Optical4FSystemPSF
from chromatix.utils import flat_phase

num_devices = 4
num_planes_per_device = 32
Expand Down
7 changes: 4 additions & 3 deletions examples/parallel_psf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from chromatix.elements import ObjectivePointSource, PhaseMask, FFLens
from chromatix import OpticalSystem
from time import perf_counter_ns

import jax
import jax.numpy as jnp
import numpy as np
from time import perf_counter_ns
from chromatix import OpticalSystem
from chromatix.elements import FFLens, ObjectivePointSource, PhaseMask

num_devices = 4
num_planes_per_device = 32
Expand Down
16 changes: 9 additions & 7 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ nav:
- Training: training.ipynb
- Higher Rank Fields: higher_rank.ipynb
- FAQ: FAQ.md
- API:
- Field: api/field.md
- Functional: api/functional.md
- Elements: api/elements.md
- Systems: api/systems.md
- Operations: api/ops.md
- Utilities: api/utils.md
- Examples:
- Holoscope: examples/holoscope.ipynb
- Computer Generated Holography: examples/cgh.ipynb
- Digital Micromirror Device: examples/dmd.ipynb
- Scalable Angular Spectrum: examples/sas.ipynb
- Bandlimited Angular Spectrum: examples/bandlimited_angular_spectrum.ipynb
- Fourier Ptychography: examples/fourier_ptychography.md
- Synchrotron X-ray Holography: examples/tomography.md
- Seidel Aberration Fitting: examples/seidel_fitting.ipynb
- Zernike Aberration Fitting: examples/zernike_fitting.ipynb
- API:
- Field: api/field.md
- Functional: api/functional.md
- Elements: api/elements.md
- Systems: api/systems.md
- Operations: api/ops.md
- Utilities: api/utils.md

theme:
logo: media/logo_symbol_white.png
Expand Down
15 changes: 11 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@ name = "chromatix"
authors = [{name = "Chromatix Team"}]
description = "Differentiable wave optics library using JAX"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = ["jax >= 0.4.1", "einops >= 0.6.0", "flax >= 0.6.3", "chex>=0.1.5", "optax >=0.1.4", "scipy >= 1.10.0"]
version = "0.1.3"
version = "0.3.0"

[project.optional-dependencies]
dev = ["black >= 23.1.0", "mypy>= 0.991", "pytest>=7.2.0", "ruff >= 0.0.246"]
dev = ["mypy>= 0.991", "pytest>=7.2.0", "ruff >= 0.4.8"]
docs = ["mkdocs >= 1.4.2", "mkdocs-material >= 9.0.6", "mkdocstrings-python >= 0.8.3", "mkdocs-jupyter"]

[tool.ruff]
ignore = ["F401", "F403"] #ignore unused imports errors
target-version = "py310"

[tool.ruff.lint]
ignore = ["E402"]
extend-select = ["I"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]
2 changes: 1 addition & 1 deletion src/chromatix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .field import ScalarField, VectorField, Field
from .field import Field, ScalarField, VectorField
from .systems import OpticalSystem
58 changes: 58 additions & 0 deletions src/chromatix/data/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import imageio
import jax.numpy as jnp
import matplotlib.pyplot as plt
from skimage import img_as_ubyte


def create_radial_pattern(shape):
"""
Create a basic radial pattern image.
Args:
shape (tuple): Shape of the image (height, width).
Returns:
jnp.ndarray: Radial pattern image.
"""
# Create a grid of coordinates
y, x = jnp.indices(shape)

# Calculate the center of the image
center_y, center_x = shape[0] // 2, shape[1] // 2

# Compute the distances from the center
distances = jnp.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)

# Normalize distances to range [0, 2*pi] for phase pattern
max_distance = jnp.sqrt(center_x**2 + center_y**2)
phase_pattern = (distances / max_distance) * 2 * jnp.pi

return phase_pattern


def save_phase_pattern():
# Create the radial pattern
shape = (512, 512)
radial_pattern = create_radial_pattern(shape)

# Save the pattern as a PNG file
plt.imshow(radial_pattern, cmap="hsv")
plt.colorbar()
plt.title("Radial Phase Pattern")
plt.axis("off") # Hide the axis
plt.savefig("data/radial_pattern.png", bbox_inches="tight", pad_inches=0)
plt.show()


def normalize_grayscale_image(input_path, output_path):
# Read the image
img = imageio.imread(input_path)

# Normalize the grayscale image
normalized_img = img / img.max()

# Convert the normalized image to 8-bit unsigned integer format
normalized_img_ubyte = img_as_ubyte(normalized_img)

# Save the normalized grayscale image as a PNG
imageio.imsave(output_path, normalized_img_ubyte)
Loading

0 comments on commit 432772b

Please sign in to comment.