Skip to content

Commit

Permalink
Merge pull request #168 from cloneofsimo/develop
Browse files: browse the repository at this point in the history
v0.1.5
  • Loading branch information
cloneofsimo authored Feb 2, 2023
2 parents c23803d + 624fa9c commit 848db91
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
10 changes: 7 additions & 3 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ def __init__(
self.captions = open(f"{instance_data_root}/caption.txt").readlines()

else:
possibily_src_images = glob.glob(
str(instance_data_root) + "/*.jpg"
) + glob.glob(str(instance_data_root) + "/*.png")
possibily_src_images = (
glob.glob(str(instance_data_root) + "/*.jpg")
+ glob.glob(str(instance_data_root) + "/*.png")
+ glob.glob(str(instance_data_root) + "/*.jpeg")
)
possibily_src_images = (
set(possibily_src_images)
- set(glob.glob(str(instance_data_root) + "/*mask.png"))
Expand Down Expand Up @@ -203,6 +205,8 @@ def __init__(
self._length = self.num_instance_images

if class_data_root is not None:
assert NotImplementedError, "Prior preservation is not implemented yet."

self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
Expand Down
6 changes: 4 additions & 2 deletions lora_diffusion/preprocess_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def clipseg_mask_generator(
@torch.no_grad()
def blip_captioning_dataset(
images: List[Image.Image],
text: Optional[str] = None,
model_id: Literal[
"Salesforce/blip-image-captioning-large",
"Salesforce/blip-image-captioning-base",
Expand All @@ -139,7 +140,7 @@ def blip_captioning_dataset(
captions = []

for image in tqdm(images):
inputs = processor(image, return_tensors="pt").to("cuda")
inputs = processor(image, text=text, return_tensors="pt").to("cuda")
out = model.generate(
**inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
)
Expand Down Expand Up @@ -243,6 +244,7 @@ def _center_of_mass(mask: Image.Image):
def load_and_save_masks_and_captions(
files: Union[str, List[str]],
output_dir: str,
caption_text: Optional[str] = None,
target_prompts: Optional[Union[List[str], str]] = None,
target_size: int = 512,
crop_based_on_salience: bool = True,
Expand Down Expand Up @@ -277,7 +279,7 @@ def load_and_save_masks_and_captions(

# captions
print(f"Generating {len(images)} captions...")
captions = blip_captioning_dataset(images)
captions = blip_captioning_dataset(images, text=caption_text)

if target_prompts is None:
target_prompts = captions
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.1.4",
version="0.1.5",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),
Expand Down

0 comments on commit 848db91

Please sign in to comment.