Skip to content

Commit

Permalink
Merge remote-tracking branch 'base/main' into palp/model-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Jul 26, 2023
2 parents 733d38b + e596332 commit 2cc5425
Show file tree
Hide file tree
Showing 31 changed files with 635 additions and 120 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/black.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: Run black
on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install venv
run: |
sudo apt-get -y install python3.10-venv
- uses: psf/black@stable
with:
options: "--check --verbose -l88"
src: "./sgm ./scripts ./main.py"
26 changes: 26 additions & 0 deletions .github/workflows/test-build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Build package

on:
push:
pull_request:

jobs:
build:
name: Build
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]
requirements-file: ["pt2", "pt13"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/${{ matrix.requirements-file }}.txt
pip install .
9 changes: 7 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# extensions
*.egg-info
*.py[cod]

# envs
.pt13
.pt2
.pt2_2

# directories
/checkpoints
/dist
/outputs
build
/build
/src
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ This is assuming you have navigated to the `generative-models` root after clonin

```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```

**PyTorch 2.0**
Expand All @@ -72,8 +71,20 @@ pip3 install -r requirements_pt13.txt
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
pip3 install -r requirements/pt2.txt
```


#### 3. Install `sgm`

```shell
pip3 install .
```

#### 4. Install `sdata` for training

```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```

## Packaging
Expand Down
11 changes: 4 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,18 @@
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import (
exists,
instantiate_from_config,
isheatmap,
)
from sgm.util import exists, instantiate_from_config, isheatmap

MULTINODE_HACKS = True

Expand Down Expand Up @@ -910,11 +906,12 @@ def divein(*args, **kwargs):
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import datetime
import os
import socket

import requests

device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
Expand Down
40 changes: 40 additions & 0 deletions requirements/pt13.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
black==23.7.0
chardet>=5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
numpy>=1.24.4
omegaconf>=2.3.0
onnx<=1.12.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.8.5
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=1.25.0
tensorboardx==2.5.1
timm>=0.9.2
tokenizers==0.12.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
torchaudio==0.13.1
torchdata==0.5.1
torchmetrics>=1.0.1
torchvision==0.14.1+cu117
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0.post1
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.16
39 changes: 39 additions & 0 deletions requirements/pt2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20
41 changes: 0 additions & 41 deletions requirements_pt13.txt

This file was deleted.

41 changes: 0 additions & 41 deletions requirements_pt2.txt

This file was deleted.

1 change: 1 addition & 0 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import (
Expand Down
6 changes: 3 additions & 3 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@


from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.inference.helpers import Img2ImgDiscretizationWrapper, embed_watermark
Expand Down
5 changes: 3 additions & 2 deletions scripts/util/detection/nsfw_and_watermark_dectection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import torch

import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import clip

RESOURCES_ROOT = "scripts/util/detection/"

Expand Down
5 changes: 2 additions & 3 deletions sgm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config, get_configs_path
from .util import get_configs_path, instantiate_from_config

__version__ = "0.0.1"
__version__ = "0.1.0"
4 changes: 2 additions & 2 deletions sgm/data/cifar10.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CIFAR10DataDictWrapper(Dataset):
Expand Down
4 changes: 2 additions & 2 deletions sgm/data/mnist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class MNISTDataDictWrapper(Dataset):
Expand Down
6 changes: 3 additions & 3 deletions sgm/modules/autoencoding/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
Expand Down
Empty file.
1 change: 1 addition & 0 deletions sgm/modules/autoencoding/lpips/loss/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vgg.pth
Loading

0 comments on commit 2cc5425

Please sign in to comment.