Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Proxy and SSL Config Options to Python SDK #3180

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions gpt4all-bindings/python/gpt4all/gpt4all.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def __init__(
n_ctx: int = 2048,
ngl: int = 100,
verbose: bool = False,
proxies=None,
verify_ssl: bool = True
):
"""
Constructor
Expand All @@ -201,11 +203,17 @@ def __init__(
n_ctx: Maximum size of context window
ngl: Number of GPU layers to use (Vulkan)
verbose: If True, print debug messages.
proxies: A dictionary of proxies to be used for remote calls.
verify_ssl: If true, verify SSL certificates. Defaults to true. Note it is generally not recommended to skip SSL verification.
"""

if proxies is None:
proxies = {}
self.model_type = model_type
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"
self._proxies = proxies
self._verify_ssl = verify_ssl

device_init = None
if sys.platform == "darwin":
Expand Down Expand Up @@ -267,14 +275,24 @@ def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history)

@staticmethod
def list_models(
    proxies: dict[str, str] | None = None,
    verify_ssl: bool = True,
) -> list[ConfigType]:
    """
    Fetch model list from https://gpt4all.io/models/models3.json.

    Args:
        proxies: Optional mapping of protocol to proxy URL, passed through to
            ``requests`` (e.g. ``{"https": "http://10.10.1.10:1080"}``).
        verify_ssl: If True (default), verify SSL certificates. Disabling
            verification is generally not recommended.

    Returns:
        The parsed model list (a list of model config dicts).

    Raises:
        ValueError: If the server responds with a non-200 HTTP status.
    """
    if proxies is None:
        proxies = {}  # requests treats an empty mapping as "no proxies"

    resp = requests.get(
        "https://gpt4all.io/models/models3.json",
        proxies=proxies,
        verify=verify_ssl,
        timeout=30,  # avoid hanging forever on an unresponsive server
    )
    if resp.status_code != 200:
        raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
    return resp.json()
Expand All @@ -286,6 +304,8 @@ def retrieve_model(
model_path: str | os.PathLike[str] | None = None,
allow_download: bool = True,
verbose: bool = False,
proxies: dict = None,
verify_ssl: bool = True,
) -> ConfigType:
"""
Find model file, and if it doesn't exist, download the model.
Expand All @@ -296,17 +316,21 @@ def retrieve_model(
~/.cache/gpt4all/.
allow_download: Allow API to download model from gpt4all.io. Default is True.
verbose: If True, print debug messages. Default is False.
proxies: A dictionary of proxies to be used for remote calls.
verify_ssl: If true, verify SSL certificates. Defaults to true.

Returns:
Model config.
"""

model_filename = append_extension_if_missing(model_name)
if proxies is None:
proxies = {}

# get the config for the model
config: ConfigType = {}
if allow_download:
available_models = cls.list_models()
available_models = cls.list_models(proxies=proxies, verify_ssl=verify_ssl)

for m in available_models:
if model_filename == m["filename"]:
Expand Down Expand Up @@ -354,6 +378,8 @@ def download_model(
url: str | None = None,
expected_size: int | None = None,
expected_md5: str | None = None,
proxies=None,
verify_ssl: bool = True
) -> str | os.PathLike[str]:
"""
Download model from gpt4all.io.
Expand All @@ -365,22 +391,25 @@ def download_model(
url: the models remote url (e.g. may be hosted on HF)
expected_size: The expected size of the download.
expected_md5: The expected MD5 hash of the download.

proxies: A dictionary of proxies to be used for remote calls.
verify_ssl: If true, verify SSL certificates. Defaults to true.
Returns:
Model file destination.
"""

# Download model
if url is None:
url = f"https://gpt4all.io/models/gguf/{model_filename}"
if proxies is None:
proxies = {}

def make_request(offset=None):
headers = {}
if offset:
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
headers['Range'] = f'bytes={offset}-' # resume incomplete response
headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges
response = requests.get(url, stream=True, headers=headers)
response = requests.get(url, stream=True, headers=headers, proxies=proxies, verify=verify_ssl)
if response.status_code not in (200, 206):
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
Expand Down