diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index be74d8699..b0ce54532 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -330,7 +330,8 @@ def __init__(self, config: Blip2Config) -> None: def prepare_simple_aesthetics_model(pretrained_model_name_or_path, - return_model=True): + return_model=True, + trust_remote_code=False): """ Prepare and load a simple aesthetics model. @@ -344,21 +345,25 @@ def prepare_simple_aesthetics_model(pretrained_model_name_or_path, AestheticsPredictorV2ReLU) from transformers import CLIPProcessor - processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path) + processor = CLIPProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) if not return_model: return processor else: if 'v1' in pretrained_model_name_or_path: model = AestheticsPredictorV1.from_pretrained( - pretrained_model_name_or_path) + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) elif ('v2' in pretrained_model_name_or_path and 'linear' in pretrained_model_name_or_path): model = AestheticsPredictorV2Linear.from_pretrained( - pretrained_model_name_or_path) + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) elif ('v2' in pretrained_model_name_or_path and 'relu' in pretrained_model_name_or_path): model = AestheticsPredictorV2ReLU.from_pretrained( - pretrained_model_name_or_path) + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) else: raise ValueError( 'Not support {}'.format(pretrained_model_name_or_path)) @@ -439,7 +444,8 @@ def decompress_model(compressed_model_path): def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type, torch_dtype='fp32', - revision='main'): + revision='main', + trust_remote_code=False): """ Prepare and load an Diffusion model from HuggingFace. @@ -493,7 +499,8 @@ def prepare_diffusion_model(pretrained_model_name_or_path, model = pipeline.from_pretrained(pretrained_model_name_or_path, revision=revision, - torch_dtype=torch_dtype) + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code) return model