Skip to content

Commit

Permalink
feat: initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 21, 2023
0 parents commit 8dac0ad
Show file tree
Hide file tree
Showing 8 changed files with 639 additions and 0 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# This workflows will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
release:
types: [created]

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__pycache__
.git
.DS_Store
poetry.lock
poetry.toml
.venv
.pre-commit-config.yaml
.mypy_cache
.ruff_cache
pyproject.toml
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Consistency Models

Implementation of Consistency Models [(Song et al., 2023)](https://arxiv.org/abs/2303.01469) in PyTorch.

<br />

![image](./assets/consistency_models.png)

## Install

```sh
$ pip install consistency
```

## Usage

```python
import torch
from diffusers import UNet2DModel
from consistency import Consistency

consistency = Consistency(
model=UNet2DModel(sample_size=224),
learning_rate=1e-4,
data_std=0.5,
time_min=0.002,
time_max=80.0,
bins_min=2,
bins_max=150,
bins_rho=7,
initial_ema_decay=0.9,
)

samples = consistency.sample(16)
samples = consistency.sample(16, steps=5) # multi-step generation
```

`Consistency` is self-contained with the training logic and all necessary schedules. It subclasses `LightningModule`, so it's supposed to be used with `Lightning.Trainer`.

```python
trainer = Trainer(max_epochs=8000, accelerator="auto")
trainer.fit(consistency, some_dataloader)
```

A complete example can be found in [this script](https://github.com/junhsss/consistency-models/blob/main/examples/train.py).

## Result

[Example wandb workspace](https://wandb.ai/junhsss/consistency/runs/pn566sjt?workspace=user-), with a batch size of **512**, **~20K** steps on `cifar10`.
Binary file added assets/consistency_models.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions consistency/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .consistency import Consistency

__all__ = ["Consistency"]
Loading

0 comments on commit 8dac0ad

Please sign in to comment.