Skip to content

Commit

Permalink
feat: IL-414 custom id can be passed to create_dataset
Browse files Browse the repository at this point in the history
* `FileSystemDatasetRepository._write_data` now uses `write_utf8` instead of opening a file.
  This should hopefully prevent the Huggingface CI errors
  • Loading branch information
FelixFehseTNG committed Apr 3, 2024
1 parent d630a8b commit b631794
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ class DatasetRepository(ABC):

@abstractmethod
def create_dataset(
self, examples: Iterable[Example[Input, ExpectedOutput]], dataset_name: str
self,
examples: Iterable[Example[Input, ExpectedOutput]],
dataset_name: str,
id: str | None = None,
) -> Dataset:
"""Creates a dataset from given :class:`Example`s and returns the ID of that dataset.
Args:
examples: An :class:`Iterable` of :class:`Example`s to be saved in the same dataset.
dataset_name: A name for the dataset.
id: The dataset ID. If `None`, an ID will be generated.
Returns:
The created :class:`Dataset`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def create_dataset(
dataset_name: str,
id: str | None = None,
) -> Dataset:
if id is None:
dataset = Dataset(name=dataset_name)
else:
dataset = Dataset(name=dataset_name, id=id)
dataset = Dataset(name=dataset_name)
if id is not None:
dataset.id = id

self.mkdir(self._dataset_directory(dataset.id))

Expand Down Expand Up @@ -146,17 +145,11 @@ def _write_data(
file_path: Path,
data_to_write: Iterable[PydanticSerializable],
) -> None:
data = "\n".join(JsonSerializer(root=chunk).model_dump_json() for chunk in data_to_write)
data = "\n".join(
JsonSerializer(root=chunk).model_dump_json() for chunk in data_to_write
)
self.write_utf8(file_path, data, create_parents=True)

# with self._file_system.open(
# self.path_to_str(file_path), "w", encoding="utf-8"
# ) as file:
# for data_chunk in data_to_write:
# serialized_result = JsonSerializer(root=data_chunk)
# json_string = serialized_result.model_dump_json() + "\n"
# file.write(json_string)


class FileDatasetRepository(FileSystemDatasetRepository):
def __init__(self, root_directory: Path) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import time
from pathlib import Path
from typing import Iterable, Optional
import typing
from typing import Optional

import huggingface_hub # type: ignore
from huggingface_hub import HfFileSystem, create_repo
from huggingface_hub.utils import HfHubHTTPError

from intelligence_layer.core.task import Input
from intelligence_layer.evaluation.dataset.domain import (
Dataset,
Example,
ExpectedOutput,
)
from intelligence_layer.evaluation.dataset.domain import Dataset
from intelligence_layer.evaluation.dataset.file_dataset_repository import (
FileSystemDatasetRepository,
)
Expand Down Expand Up @@ -50,25 +42,6 @@ def __init__(self, repository_id: str, token: str, private: bool) -> None:
self._repository_id = repository_id
self._file_system = file_system # for better type checks

# def create_dataset(
# self,
# examples: Iterable[Example[Input, ExpectedOutput]],
# dataset_name: str,
# id: str | None = None,
# ) -> Dataset:
# failures = 0
# exception = None
# while failures < 5:
# try:
# dataset = super().create_dataset(examples, dataset_name, id)
# return dataset
# except Exception as e:
# exception = typing.cast(HfHubHTTPError, e)
# failures += 1
# print(f"Failure {failures}")
# time.sleep(0.5)
# raise exception # RuntimeError("Cannot create dataset on Huggingface.")

def delete_repository(self) -> None:
huggingface_hub.delete_repo(
repo_id=self._repository_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def __init__(self) -> None:
] = {}

def create_dataset(
self, examples: Iterable[Example[Input, ExpectedOutput]], dataset_name: str
self,
examples: Iterable[Example[Input, ExpectedOutput]],
dataset_name: str,
id: str | None = None,
) -> Dataset:
dataset = Dataset(name=dataset_name)
if id is not None:
dataset.id = id
if dataset.id in self._datasets_and_examples:
raise ValueError(
f"Created random dataset ID already exists for dataset {dataset}. This should not happen."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def __init__(
self._huggingface_dataset = huggingface_dataset

def create_dataset(
self, examples: Iterable[Example[Input, ExpectedOutput]], dataset_name: str
self,
examples: Iterable[Example[Input, ExpectedOutput]],
dataset_name: str,
id: str | None = None,
) -> Dataset:
raise NotImplementedError

Expand Down
20 changes: 20 additions & 0 deletions tests/evaluation/test_dataset_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ def file_dataset_repository(tmp_path: Path) -> FileDatasetRepository:
]


@mark.parametrize(
"repository_fixture",
test_repository_fixtures,
)
def test_dataset_repository_with_custom_id(
repository_fixture: str,
request: FixtureRequest,
dummy_string_example: Example[DummyStringInput, DummyStringOutput],
) -> None:
dataset_repository: DatasetRepository = request.getfixturevalue(repository_fixture)

dataset = dataset_repository.create_dataset(
examples=[dummy_string_example],
dataset_name="test-dataset",
id="my-custom-dataset-id",
)

assert dataset.id == "my-custom-dataset-id"


@mark.parametrize(
"repository_fixture",
test_repository_fixtures,
Expand Down

0 comments on commit b631794

Please sign in to comment.