Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: chex-ify testsuite #12

Merged
merged 8 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ $ uv sync
## Running

```shell
$ uv jflux
$ uv run jflux
```

## References
Expand Down
20 changes: 0 additions & 20 deletions .pre-commit-config.yaml

This file was deleted.

15 changes: 15 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff",
},
}
4 changes: 0 additions & 4 deletions jflux/__main__.py

This file was deleted.

13 changes: 9 additions & 4 deletions jflux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import jax
import jax.numpy as jnp
import numpy as np
from einops import rearrange
from fire import Fire
from flax import nnx
from jax.typing import DTypeLike
from PIL import Image

from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
Expand Down Expand Up @@ -124,7 +123,8 @@ def main(
by the index of the sample
prompt: Prompt used for sampling
device: Pytorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
num_steps: number of sampling steps
(default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
Expand Down Expand Up @@ -216,7 +216,12 @@ def main(
x = x.clip(-1, 1)
x = rearrange(x[0], "c h w -> h w c")

img = Image.fromarray((127.5 * (x + 1.0)))
x = 127.5 * (x + 1.0)
x_numpy = np.array(x.astype(jnp.uint8))
img = Image.fromarray(x_numpy)

img.save(fn, quality=95, subsampling=0)
idx += 1

if loop:
print("-" * 80)
Expand Down
2 changes: 1 addition & 1 deletion jflux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, params: FluxParams):
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" # noqa: E501
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
Expand Down
3 changes: 3 additions & 0 deletions jflux/port.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from einops import rearrange


##############################################################################################
# AUTOENCODER MODEL PORTING
##############################################################################################
Expand Down Expand Up @@ -481,3 +482,5 @@ def port_flux(flux, tensors):
tensors=tensors,
prefix="final_layer",
)

return flux
2 changes: 0 additions & 2 deletions jflux/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from dataclasses import dataclass

import jax
import torch # need for t5 and clip
from flax import nnx
from huggingface_hub import hf_hub_download
from jax import numpy as jnp
from jax.typing import DTypeLike
from safetensors import safe_open

from jflux.model import Flux, FluxParams
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies = [
"einops>=0.8.0",
"fire>=0.6.0",
"flax>=0.9.0",
"jflux",
# FIXME: Allow for local installation without GPUs as well `jax[cuda12]`
"jax>=0.4.31",
"mypy>=1.11.2",
Expand All @@ -22,6 +21,7 @@ dependencies = [
jflux = "jflux.cli:app"

[tool.uv]
package = true
dev-dependencies = [
"flux",
"pytest>=8.3.3",
Expand All @@ -32,7 +32,10 @@ jflux = { workspace = true }
flux = { git = "https://github.com/black-forest-labs/flux.git" }

[tool.ruff.lint]
select = ["I001"]
select = ["E", "F", "I001", "W"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just asking -- what does this do?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So previously, we weren't utilising most of the nice ruff rules. This snippet enables some nice rules from https://docs.astral.sh/ruff/rules/


[tool.ruff.lint.isort]
lines-after-imports = 2

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading