Skip to content

Commit

Permalink
change the fl.Slincing interface
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Dec 12, 2023
1 parent 06822af commit 7dbf396
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 15 deletions.
12 changes: 7 additions & 5 deletions src/refiners/fluxion/layers/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,18 @@ def forward(self, x: Tensor) -> Tensor:


class Slicing(Module):
def __init__(self, dim: int, start: int, length: int) -> None:
def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
super().__init__()
self.dim = dim
self.start = start
self.length = length
self.end = end
self.step = step

def forward(self, x: Tensor) -> Tensor:
if self.length < 0:
return x.narrow(self.dim, self.start, x.shape[self.dim] - self.start + self.length + 1)
return x.narrow(self.dim, self.start, self.length)
start = self.start if self.start >= 0 else x.shape[self.dim] + self.start
end = self.end if self.end >= 0 else x.shape[self.dim] + self.end
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
return x.index_select(self.dim, indices)


class Squeeze(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, device: Device | str | None = None, dtype: DType | None = Non
),
Chain(
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
Slicing(dim=1, start=0, length=4),
Slicing(dim=1, end=4),
),
)

Expand Down
10 changes: 6 additions & 4 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,11 @@ def __init__(
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length),
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
fl.Slicing(dim=1, start=text_sequence_length, end=image_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
Expand All @@ -280,11 +280,13 @@ def __init__(
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length),
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
fl.Slicing(
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length + 1
),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.scale = scale
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Slicing(dim=1, start=0, length=4), # support inpainting
Slicing(dim=1, end=4), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
)
Expand Down
8 changes: 4 additions & 4 deletions src/refiners/foundationals/segment_anything/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
super().__init__(
*[
fl.Chain(
fl.Slicing(dim=1, start=i + 1, length=1),
fl.Slicing(dim=1, start=i + 1, end=i + 2),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=embedding_dim // 8,
Expand Down Expand Up @@ -156,7 +156,7 @@ def __init__(
),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
fl.Slicing(dim=1, start=1, length=num_mask_tokens),
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
)

Expand All @@ -173,7 +173,7 @@ def __init__(
self.embedding_dim = embedding_dim
self.num_layers = num_layers
super().__init__(
fl.Slicing(dim=1, start=0, length=1),
fl.Slicing(dim=1, start=0, end=2),
fl.Squeeze(dim=0),
fl.MultiLinear(
input_dim=embedding_dim,
Expand All @@ -183,7 +183,7 @@ def __init__(
device=device,
dtype=dtype,
),
fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
)


Expand Down

0 comments on commit 7dbf396

Please sign in to comment.