-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6e5fbd0
commit 4ee9de8
Showing
4 changed files
with
110 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .base import PromptNode | ||
from .nodes import GuidedPromptNode | ||
|
||
__all__ = [ | ||
"PromptNode", | ||
"GuidedPromptNode", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from typing import Optional, Dict, Union | ||
|
||
from haystack.nodes import PromptNode as HaystackPromptNode | ||
from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate | ||
|
||
|
||
class Node(ABC): | ||
@abstractmethod | ||
def run(self, prompt_template): | ||
pass | ||
|
||
|
||
class PromptNode(Node): | ||
def __init__(self, model_name_or_path: str, *args, **kwargs) -> None: | ||
self._prompt_node = HaystackPromptNode(model_name_or_path, *args, **kwargs) | ||
|
||
def run( | ||
self, | ||
prompt_template: Optional[Union[str, HaystackPromptTemplate]], | ||
invocation_context: Optional[Dict[str, any]] = None, | ||
): | ||
return self._prompt_node.run(prompt_template, invocation_context=invocation_context)[0]["results"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import json | ||
|
||
from typing import Optional, Dict, Union | ||
|
||
from haystack.nodes import PromptNode as HaystackPromptNode | ||
from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate | ||
|
||
from .base import Node | ||
|
||
try: | ||
import outlines.models as models | ||
import outlines.text.generate as generate | ||
|
||
from pydantic import BaseModel | ||
|
||
import torch | ||
except ImportError as exc: | ||
raise ImportError("Try 'pip install outlines'") from exc | ||
|
||
|
||
class GuidedPromptNode(Node): | ||
def __init__( | ||
self, | ||
model_name_or_path: str, | ||
schema: Union[str, BaseModel], | ||
max_length: int = 100, | ||
device: Optional[str] = None, | ||
model_kwargs: Dict = None, | ||
manual_seed: Optional[int] = None, | ||
) -> None: | ||
self.max_length = max_length | ||
model_kwargs = model_kwargs or {} | ||
self._model = models.transformers(model_name_or_path, device=device, **model_kwargs) | ||
# JSON schema of class | ||
if not isinstance(schema, str): | ||
schema = json.dumps(schema.schema()) | ||
|
||
self._generator = generate.json( | ||
self._model, | ||
schema, | ||
self.max_length, | ||
) | ||
|
||
self.rng = torch.Generator(device=device) | ||
if manual_seed is not None: | ||
self.rng.manual_seed(manual_seed) | ||
|
||
def run(self, prompt_template: HaystackPromptTemplate, **kwargs): | ||
return self._generator(prompt_template.prompt_text, rng=self.rng) |