Skip to content

Commit

Permalink
Fixed attention_slice size to be compatible with SD2 on the pipes
Browse files Browse the repository at this point in the history
  • Loading branch information
Skquark committed Nov 25, 2022
1 parent 48c6384 commit 1afae88
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 18 deletions.
11 changes: 8 additions & 3 deletions examples/community/clip_guided_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ def __init__(

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
    r"""
    Enable sliced attention computation to reduce peak memory usage.

    Args:
        slice_size (`str` or `int`, defaults to `"auto"`):
            When `"auto"`, a slice size is derived from the UNet config;
            otherwise the given integer is forwarded as-is. Per the pipeline
            contract, `attention_head_dim` must be a multiple of `slice_size`.
    """
    if slice_size == "auto":
        # SD2-style UNets store `attention_head_dim` as a list (one entry per
        # block), so `// 2` on it would raise a TypeError — branch on the type.
        if isinstance(self.unet.config.attention_head_dim, int):
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        else:
            # if `attention_head_dim` is a list, take the smallest head size
            slice_size = min(self.unet.config.attention_head_dim)

    self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down
10 changes: 7 additions & 3 deletions examples/community/composable_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down
11 changes: 8 additions & 3 deletions examples/community/imagic_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down
11 changes: 8 additions & 3 deletions examples/community/interpolate_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down
11 changes: 8 additions & 3 deletions examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
Expand Down

0 comments on commit 1afae88

Please sign in to comment.