add fake_reply parameter to GPT4All.generate() #2935

Open · wants to merge 1 commit into base: main
3 changes: 2 additions & 1 deletion gpt4all-bindings/python/gpt4all/_pyllmodel.py

@@ -493,6 +493,7 @@ def prompt_model(
         context_erase: float = 0.75,
         reset_context: bool = False,
         special: bool = False,
+        fake_reply: str = "",
     ):
         """
         Generate response from model from a prompt.

@@ -537,7 +538,7 @@ def prompt_model(
             True,
             self.context,
             special,
-            ctypes.c_char_p(),
+            ctypes.c_char_p(fake_reply.encode()) if fake_reply else ctypes.c_char_p(),
         )
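For context on the `_pyllmodel.py` change: `ctypes.c_char_p()` called with no argument yields a NULL `char *`, which the native backend reads as "no fake reply", while a non-empty string is UTF-8 encoded before crossing the FFI boundary. A minimal sketch of that dispatch in isolation (the helper name `to_c_string` is illustrative, not part of the PR):

    import ctypes

    def to_c_string(fake_reply: str) -> ctypes.c_char_p:
        # An empty string maps to a NULL char pointer, which the native
        # library treats as "generate a real reply as usual".
        return ctypes.c_char_p(fake_reply.encode()) if fake_reply else ctypes.c_char_p()

    assert to_c_string("").value is None              # NULL pointer
    assert to_c_string("Paris.").value == b"Paris."   # UTF-8 encoded bytes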
3 changes: 3 additions & 0 deletions gpt4all-bindings/python/gpt4all/gpt4all.py

@@ -496,6 +496,7 @@ def generate(
         n_batch: int = 8,
         n_predict: int | None = None,
         streaming: bool = False,
+        fake_reply: str = "",
         callback: ResponseCallbackType = empty_response_callback,
     ) -> Any:
         """

@@ -513,6 +514,7 @@ def generate(
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
             n_predict: Equivalent to max_tokens, exists for backwards compatibility.
             streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            fake_reply: A spoofed reply for the given prompt, used as a way to load chat history.
             callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.

         Returns:

@@ -529,6 +531,7 @@ def generate(
             repeat_last_n=repeat_last_n,
             n_batch=n_batch,
             n_predict=n_predict if n_predict is not None else max_tokens,
+            fake_reply=fake_reply,
        )

        if self._history is not None:
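Taken together, these changes let a caller replay earlier conversation turns without re-generating them: the prompt is processed into the model's context, but the supplied `fake_reply` stands in for the model's output. A hypothetical usage sketch (the model file name and prompts are illustrative, not from this PR):

    from gpt4all import GPT4All

    model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # any local model file

    with model.chat_session():
        # Replay a past exchange: the prompt enters the context, but the
        # reply is injected via fake_reply instead of being generated.
        model.generate("What is the capital of France?", fake_reply="Paris.")
        # Later generations see the spoofed turn as ordinary chat history.
        print(model.generate("What country is that city in?"))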