Commit: Working draft
HallerPatrick committed Nov 2, 2023
1 parent 6e5fbd0 commit 4ee9de8
Showing 4 changed files with 110 additions and 22 deletions.
52 changes: 30 additions & 22 deletions src/fabricator/dataset_generator.py
@@ -61,7 +61,7 @@ def generate(
         num_samples_to_generate: int = 10,
         timeout_per_prompt: Optional[int] = None,
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
         """Generate a dataset based on a prompt template and support examples.
         Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -93,8 +93,11 @@ def generate(
         if fewshot_dataset:
             self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset)
 
-        assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \
-            "Sampling strategy must be 'uniform' or 'stratified'"
+        assert fewshot_sampling_strategy in [
+            None,
+            "uniform",
+            "stratified",
+        ], "Sampling strategy must be 'uniform' or 'stratified'"
 
         if fewshot_dataset and not fewshot_sampling_column:
             fewshot_sampling_column = prompt_template.generate_data_for_column[0]
@@ -111,7 +114,7 @@ def generate(
             num_samples_to_generate,
             timeout_per_prompt,
             log_every_n_api_calls,
-            dummy_response
+            dummy_response,
         )
 
         if return_unlabeled_dataset:
@@ -134,7 +137,6 @@ def _try_generate(
         """
 
         if dummy_response:
-
             if isinstance(dummy_response, str):
                 logger.info(f"Returning dummy response: {dummy_response}")
                 return dummy_response
@@ -152,7 +154,7 @@ def _try_generate(
             prediction = self.prompt_node.run(
                 prompt_template=HaystackPromptTemplate(prompt=prompt_text),
                 invocation_context=invocation_context,
-            )[0]["results"]
+            )
         except Exception as error:
             logger.error(f"Error while generating example: {error}")
             return None
@@ -172,7 +174,7 @@ def _inner_generate_loop(
         num_samples_to_generate: int,
         timeout_per_prompt: Optional[int],
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ):
         current_tries_left = self._max_tries
         current_log_file = self._setup_log(prompt_template)
@@ -200,8 +202,11 @@ def _inner_generate_loop(
             if fewshot_dataset:
                 prompt_labels, fewshot_examples = self._sample_fewshot_examples(
-                    prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
-                    fewshot_sampling_column
+                    prompt_template,
+                    fewshot_dataset,
+                    fewshot_sampling_strategy,
+                    fewshot_examples_per_class,
+                    fewshot_sampling_column,
                 )
 
             prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples)
@@ -231,6 +236,7 @@ def _inner_generate_loop(
                         f" {len(generated_dataset)} examples."
                     )
                     break
+                continue
 
             if len(prediction) == 1:
                 prediction = prediction[0]
@@ -310,8 +316,9 @@ def _convert_prediction(self, prediction: str, target_type: type) -> Any:
             return target_type(prediction)
         except ValueError:
             logger.warning(
-                "Could not convert prediction {} to type {}. "
-                "Returning original prediction.", repr(prediction), target_type
+                "Could not convert prediction {} to type {}. " "Returning original prediction.",
+                repr(prediction),
+                target_type,
             )
             return prediction
 
@@ -321,21 +328,20 @@ def _sample_fewshot_examples(
         fewshot_dataset: Dataset,
         fewshot_sampling_strategy: str,
         fewshot_examples_per_class: int,
-        fewshot_sampling_column: str
+        fewshot_sampling_column: str,
     ) -> Tuple[Union[List[str], str], Dataset]:
-
         if fewshot_sampling_strategy == "uniform":
             prompt_labels = choice(prompt_template.label_options, 1)[0]
-            fewshot_examples = fewshot_dataset.filter(
-                lambda example: example[fewshot_sampling_column] == prompt_labels
-            ).shuffle().select(range(fewshot_examples_per_class))
+            fewshot_examples = (
+                fewshot_dataset.filter(lambda example: example[fewshot_sampling_column] == prompt_labels)
+                .shuffle()
+                .select(range(fewshot_examples_per_class))
+            )
 
         elif fewshot_sampling_strategy == "stratified":
             prompt_labels = prompt_template.label_options
             fewshot_examples = single_label_stratified_sample(
-                fewshot_dataset,
-                fewshot_sampling_column,
-                fewshot_examples_per_class
+                fewshot_dataset, fewshot_sampling_column, fewshot_examples_per_class
             )
 
         else:
@@ -345,9 +351,11 @@ def _sample_fewshot_examples(
         else:
             fewshot_examples = fewshot_dataset.shuffle()
 
-        assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}." \
-                                          f"Ensure that labels of fewshot examples match the label_options " \
-                                          f"from the prompt."
+        assert len(fewshot_examples) > 0, (
+            f"Could not find any fewshot examples for label(s) {prompt_labels}."
+            f"Ensure that labels of fewshot examples match the label_options "
+            f"from the prompt."
+        )
 
         return prompt_labels, fewshot_examples
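For orientation, here is a minimal, hypothetical driver for the `generate()` signature above, using the `dummy_response` short-circuit to exercise the generation loop without real LLM calls. The `BasePrompt` fields mirror the attributes this diff reads (`label_options`, `generate_data_for_column`); the dataset contents, column names, and model name are invented for illustration:

from datasets import Dataset

from fabricator import DatasetGenerator
from fabricator.nodes import PromptNode
from fabricator.prompts import BasePrompt

# Tiny fewshot pool; the column names are illustrative.
fewshot = Dataset.from_dict(
    {"text": ["Great movie!", "Dull plot."], "label": ["positive", "negative"]}
)

prompt = BasePrompt(
    task_description="Annotate movie reviews as either: {}",
    label_options=["positive", "negative"],
    generate_data_for_column="label",
    fewshot_example_columns="text",
)

# The model is never actually called here: dummy_response short-circuits
# _try_generate before the prompt node runs.
generator = DatasetGenerator(PromptNode("google/flan-t5-base"))

generated = generator.generate(
    prompt_template=prompt,
    fewshot_dataset=fewshot,
    fewshot_sampling_strategy="stratified",  # must be None, "uniform", or "stratified"
    fewshot_examples_per_class=1,
    fewshot_sampling_column="label",
    num_samples_to_generate=4,
    dummy_response="positive",  # every sample gets this canned reply
)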
7 changes: 7 additions & 0 deletions src/fabricator/nodes/__init__.py
@@ -0,0 +1,7 @@
+from .base import PromptNode
+from .nodes import GuidedPromptNode
+
+__all__ = [
+    "PromptNode",
+    "GuidedPromptNode",
+]
24 changes: 24 additions & 0 deletions src/fabricator/nodes/base.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+
+from typing import Optional, Dict, Union
+
+from haystack.nodes import PromptNode as HaystackPromptNode
+from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate
+
+
+class Node(ABC):
+    @abstractmethod
+    def run(self, prompt_template):
+        pass
+
+
+class PromptNode(Node):
+    def __init__(self, model_name_or_path: str, *args, **kwargs) -> None:
+        self._prompt_node = HaystackPromptNode(model_name_or_path, *args, **kwargs)
+
+    def run(
+        self,
+        prompt_template: Optional[Union[str, HaystackPromptTemplate]],
+        invocation_context: Optional[Dict[str, any]] = None,
+    ):
+        return self._prompt_node.run(prompt_template, invocation_context=invocation_context)[0]["results"]
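This wrapper absorbs the `[0]["results"]` unwrapping that `_try_generate` previously did at the call site (see the dataset_generator.py hunk above). A minimal sketch of using it directly, assuming an OpenAI-backed Haystack model; the model name and key are placeholders:

from fabricator.nodes import PromptNode

# Arguments are forwarded verbatim to haystack's PromptNode.
node = PromptNode("gpt-3.5-turbo", api_key="sk-...")

# run() returns the bare list of generated strings rather than haystack's
# (output_dict, edge) tuple.
reviews = node.run("Generate a one-sentence movie review.")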
49 changes: 49 additions & 0 deletions src/fabricator/nodes/nodes.py
@@ -0,0 +1,49 @@
+import json
+
+from typing import Optional, Dict, Union
+
+from haystack.nodes import PromptNode as HaystackPromptNode
+from haystack.nodes.prompt import PromptTemplate as HaystackPromptTemplate
+
+from .base import Node
+
+try:
+    import outlines.models as models
+    import outlines.text.generate as generate
+
+    from pydantic import BaseModel
+
+    import torch
+except ImportError as exc:
+    raise ImportError("Try 'pip install outlines'") from exc
+
+
+class GuidedPromptNode(Node):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        schema: Union[str, BaseModel],
+        max_length: int = 100,
+        device: Optional[str] = None,
+        model_kwargs: Dict = None,
+        manual_seed: Optional[int] = None,
+    ) -> None:
+        self.max_length = max_length
+        model_kwargs = model_kwargs or {}
+        self._model = models.transformers(model_name_or_path, device=device, **model_kwargs)
+        # JSON schema of class
+        if not isinstance(schema, str):
+            schema = json.dumps(schema.schema())
+
+        self._generator = generate.json(
+            self._model,
+            schema,
+            self.max_length,
+        )
+
+        self.rng = torch.Generator(device=device)
+        if manual_seed is not None:
+            self.rng.manual_seed(manual_seed)
+
+    def run(self, prompt_template: HaystackPromptTemplate, **kwargs):
+        return self._generator(prompt_template.prompt_text, rng=self.rng)
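And a hypothetical end-to-end use of the new `GuidedPromptNode`: outlines constrains token sampling so the output is guaranteed to parse as JSON matching the schema. The pydantic model and checkpoint name are invented for illustration:

from haystack.nodes.prompt import PromptTemplate
from pydantic import BaseModel

from fabricator.nodes import GuidedPromptNode


class Review(BaseModel):
    text: str
    label: str


node = GuidedPromptNode(
    "openlm-research/open_llama_3b",  # any transformers checkpoint outlines can load
    schema=Review,                    # a pydantic class or a raw JSON-schema string
    max_length=128,
    manual_seed=42,                   # seeds the torch.Generator for reproducible sampling
)

output = node.run(PromptTemplate(prompt="Write one movie review as JSON."))
# Depending on the outlines version, output is the JSON text or an already
# parsed object; either way it conforms to the Review schema.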
