From 3eaba7f96760002673eded2828fda4aaa087a578 Mon Sep 17 00:00:00 2001 From: Yilun Huang Date: Mon, 19 Aug 2024 17:38:34 +0800 Subject: [PATCH] + add use_cuda for get_model funcs in two OPs (#389) --- data_juicer/ops/filter/video_aesthetics_filter.py | 4 +++- data_juicer/ops/mapper/image_diffusion_mapper.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index 131c82f4b..ddb13aa4f 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -159,7 +159,9 @@ def compute_stats(self, sample, rank=None, context=False): if len(frame_images) > 0: # compute aesthetics_scores - model, processor = get_model(self.model_key, rank=rank) + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=self.use_cuda()) inputs = processor(images=frame_images, return_tensors='pt').to(model.device) with torch.no_grad(): diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index 776d08cb8..26f0f8403 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -126,7 +126,9 @@ def _real_guidance(self, caption: str, image: Image.Image, rank=None): canvas = image.resize((512, 512), Image.BILINEAR) prompt = caption - diffusion_model = get_model(model_key=self.model_key, rank=rank) + diffusion_model = get_model(model_key=self.model_key, + rank=rank, + use_cuda=self.use_cuda()) kwargs = dict(image=canvas, prompt=[prompt],