diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py
index a8e81b0..eed4403 100644
--- a/lora_diffusion/dataset.py
+++ b/lora_diffusion/dataset.py
@@ -139,9 +139,11 @@ def __init__(
             self.captions = open(f"{instance_data_root}/caption.txt").readlines()
         else:
-            possibily_src_images = glob.glob(
-                str(instance_data_root) + "/*.jpg"
-            ) + glob.glob(str(instance_data_root) + "/*.png")
+            possibily_src_images = (
+                glob.glob(str(instance_data_root) + "/*.jpg")
+                + glob.glob(str(instance_data_root) + "/*.png")
+                + glob.glob(str(instance_data_root) + "/*.jpeg")
+            )
             possibily_src_images = (
                 set(possibily_src_images)
                 - set(glob.glob(str(instance_data_root) + "/*mask.png"))
@@ -203,6 +205,8 @@ def __init__(
         self._length = self.num_instance_images
 
         if class_data_root is not None:
+            raise NotImplementedError("Prior preservation is not implemented yet.")
+
             self.class_data_root = Path(class_data_root)
             self.class_data_root.mkdir(parents=True, exist_ok=True)
             self.class_images_path = list(self.class_data_root.iterdir())
diff --git a/lora_diffusion/preprocess_files.py b/lora_diffusion/preprocess_files.py
index bc54eea..bedb89f 100644
--- a/lora_diffusion/preprocess_files.py
+++ b/lora_diffusion/preprocess_files.py
@@ -121,6 +121,7 @@ def clipseg_mask_generator(
 @torch.no_grad()
 def blip_captioning_dataset(
     images: List[Image.Image],
+    text: Optional[str] = None,
     model_id: Literal[
         "Salesforce/blip-image-captioning-large",
         "Salesforce/blip-image-captioning-base",
@@ -139,7 +140,7 @@ def blip_captioning_dataset(
     captions = []
 
     for image in tqdm(images):
-        inputs = processor(image, return_tensors="pt").to("cuda")
+        inputs = processor(image, text=text, return_tensors="pt").to("cuda")
         out = model.generate(
             **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
         )
@@ -243,6 +244,7 @@ def _center_of_mass(mask: Image.Image):
 def load_and_save_masks_and_captions(
     files: Union[str, List[str]],
     output_dir: str,
+    caption_text: Optional[str] = None,
     target_prompts: Optional[Union[List[str], str]] = None,
     target_size: int = 512,
     crop_based_on_salience: bool = True,
@@ -277,7 +279,7 @@ def load_and_save_masks_and_captions(
     # captions
     print(f"Generating {len(images)} captions...")
-    captions = blip_captioning_dataset(images)
+    captions = blip_captioning_dataset(images, text=caption_text)
 
     if target_prompts is None:
         target_prompts = captions
diff --git a/setup.py b/setup.py
index b27850c..d49c151 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.1.4",
+    version="0.1.5",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
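
For reference, the new `text` argument leans on BLIP's conditional captioning: when the processor is given a text prefix alongside the image, `generate` continues that prefix, so every caption starts with the supplied phrase. A minimal standalone sketch of that behavior, with decoding settings mirroring the diff (the image path and prompt string are illustrative, not from the repo):

```python
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

model_id = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).to("cuda")

image = Image.open("example.jpg").convert("RGB")  # illustrative path

# With text=None this is plain captioning; with a prefix, BLIP
# treats it as the start of the caption and completes it.
inputs = processor(image, text="a photo of sks person", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(
        **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
    )
print(processor.decode(out[0], skip_special_tokens=True))
```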
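
Downstream, the same knob is exposed on the preprocessing entry point as `caption_text`. A usage sketch under the signature shown in the diff, assuming a local folder of training images (directory names are illustrative):

```python
from lora_diffusion.preprocess_files import load_and_save_masks_and_captions

# caption_text is forwarded to blip_captioning_dataset(text=...),
# so every generated caption begins with this prefix.
load_and_save_masks_and_captions(
    files="./instance_images",   # illustrative folder of .jpg/.png/.jpeg files
    output_dir="./preprocessed",
    caption_text="a photo of sks person",
)
```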