diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index c863817dfb07..8f230d5ec14b 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -176,6 +176,8 @@ def __init__( n_ctx: int = 2048, ngl: int = 100, verbose: bool = False, + proxies=None, + verify_ssl: bool = True ): """ Constructor @@ -201,11 +203,17 @@ def __init__( n_ctx: Maximum size of context window ngl: Number of GPU layers to use (Vulkan) verbose: If True, print debug messages. + proxies: A dictionary of proxies to be used for remote calls. + verify_ssl: If true, verify SSL certificates. Defaults to true. Note it is generally not recommended to skip SSL verification. """ + if proxies is None: + proxies = {} self.model_type = model_type self._history: list[MessageType] | None = None self._current_prompt_template: str = "{0}" + self._proxies = proxies + self._verify_ssl = verify_ssl device_init = None if sys.platform == "darwin": @@ -267,14 +275,24 @@ def current_chat_session(self) -> list[MessageType] | None: return None if self._history is None else list(self._history) @staticmethod - def list_models() -> list[ConfigType]: + def list_models( + proxies=None, + verify_ssl: bool = True + ) -> list[ConfigType]: """ Fetch model list from https://gpt4all.io/models/models3.json. + Args: + proxies: A dictionary of proxies to be used for remote calls. + verify_ssl: If true, verify SSL certificates. Defaults to true. + Returns: Model list in JSON format. """ - resp = requests.get("https://gpt4all.io/models/models3.json") + if proxies is None: + proxies = {} + + resp = requests.get("https://gpt4all.io/models/models3.json", proxies=proxies, verify=verify_ssl) if resp.status_code != 200: raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}') return resp.json() @@ -286,6 +304,8 @@ def retrieve_model( model_path: str | os.PathLike[str] | None = None, allow_download: bool = True, verbose: bool = False, + proxies: dict = None, + verify_ssl: bool = True, ) -> ConfigType: """ Find model file, and if it doesn't exist, download the model. @@ -296,17 +316,21 @@ def retrieve_model( ~/.cache/gpt4all/. allow_download: Allow API to download model from gpt4all.io. Default is True. verbose: If True (default), print debug messages. + proxies: A dictionary of proxies to be used for remote calls. + verify_ssl: If true, verify SSL certificates. Defaults to true. Returns: Model config. """ model_filename = append_extension_if_missing(model_name) + if proxies is None: + proxies = {} # get the config for the model config: ConfigType = {} if allow_download: - available_models = cls.list_models() + available_models = cls.list_models(proxies=proxies, verify_ssl=verify_ssl) for m in available_models: if model_filename == m["filename"]: @@ -354,6 +378,8 @@ def download_model( url: str | None = None, expected_size: int | None = None, expected_md5: str | None = None, + proxies=None, + verify_ssl: bool = True ) -> str | os.PathLike[str]: """ Download model from gpt4all.io. @@ -365,7 +391,8 @@ def download_model( url: the models remote url (e.g. may be hosted on HF) expected_size: The expected size of the download. expected_md5: The expected MD5 hash of the download. - + proxies: A dictionary of proxies to be used for remote calls. + verify_ssl: If true, verify SSL certificates. Defaults to true. Returns: Model file destination. """ @@ -373,6 +400,8 @@ def download_model( # Download model if url is None: url = f"https://gpt4all.io/models/gguf/{model_filename}" + if proxies is None: + proxies = {} def make_request(offset=None): headers = {} @@ -380,7 +409,7 @@ def make_request(offset=None): print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr) headers['Range'] = f'bytes={offset}-' # resume incomplete response headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges - response = requests.get(url, stream=True, headers=headers) + response = requests.get(url, stream=True, headers=headers, proxies=proxies, verify=verify_ssl) if response.status_code not in (200, 206): raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}') if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):