Merge pull request #1177 from ZiyiXia/master
Docstring of abc
  • Loading branch information
ZiyiXia authored Oct 31, 2024
2 parents dd7d32b + 7ae0ecf commit d76e51c
Showing 9 changed files with 400 additions and 12 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -139,4 +139,9 @@ pic2.py
.pyre/

# MacOS associated
.DS_Store
.DS_Store

# results
en_results
zh_results
docs
6 changes: 6 additions & 0 deletions FlagEmbedding/abc/evaluation/arguments.py
@@ -8,6 +8,9 @@

@dataclass
class AbsEvalArgs:
"""
Base class for evaluation arguments.
"""
eval_name: str = field(
default=None,
metadata={"help": "The name of the evaluation task, such as msmarco, beir, miracl, etc."}
@@ -77,6 +80,9 @@ class AbsEvalArgs:

@dataclass
class AbsEvalModelArgs:
"""
Base class for model arguments during evaluation.
"""
embedder_name_or_path: str = field(
metadata={"help": "The embedder name or path.", "required": True}
)
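Both dataclasses in this file use the field(metadata={"help": ...}) pattern that transformers.HfArgumentParser consumes. A minimal sketch of how they could be parsed from the command line follows; using HfArgumentParser here is an assumption based on that pattern, not something shown in this diff:

from transformers import HfArgumentParser
from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs, AbsEvalModelArgs

# Parse both argument groups from the command line in one pass, e.g.
#   python eval.py --eval_name msmarco --embedder_name_or_path BAAI/bge-m3
parser = HfArgumentParser((AbsEvalArgs, AbsEvalModelArgs))
eval_args, model_args = parser.parse_args_into_dataclasses()

print(eval_args.eval_name)               # e.g. "msmarco"
print(model_args.embedder_name_or_path)  # marked "required": True above
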
179 changes: 179 additions & 0 deletions FlagEmbedding/abc/evaluation/data_loader.py
@@ -12,6 +12,16 @@


class AbsEvalDataLoader(ABC):
"""
Base class of data loader for evaluation.
Args:
eval_name (str): The experiment name of current evaluation.
dataset_dir (str, optional): path to the datasets. Defaults to None.
cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to None.
token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to None.
force_redownload: If True, will force redownload the dataset to cover the local dataset. Defaults to False.
"""
def __init__(
self,
eval_name: str,
@@ -43,6 +53,17 @@ def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
pass

def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]:
"""Check the validity of dataset names
Args:
dataset_names (Union[str, List[str]]): a dataset name (str) or a list of dataset names (List[str])
Raises:
ValueError
Returns:
List[str]: List of valid dataset names.
"""
available_dataset_names = self.available_dataset_names()
if isinstance(dataset_names, str):
dataset_names = [dataset_names]
@@ -53,6 +74,15 @@ def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]
return dataset_names

def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str] = None) -> List[str]:
"""Check whether the splits are available in the dataset.
Args:
splits (Union[str, List[str]]): Splits to check.
dataset_name (Optional[str], optional): Name of dataset to check. Defaults to None.
Returns:
List[str]: The available splits.
"""
available_splits = self.available_splits(dataset_name=dataset_name)
if isinstance(splits, str):
splits = [splits]
@@ -65,6 +95,14 @@ def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str
return checked_splits

def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
"""Load the corpus from the dataset.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
Returns:
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
"""
if self.dataset_dir is not None:
if dataset_name is None:
save_dir = self.dataset_dir
@@ -75,6 +113,18 @@ def _load_remote_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
return self._load_remote_corpus(dataset_name=dataset_name)

def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Load the corpus from the dataset.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): The split to load relevance from. Defaults to 'test'.
Raises:
ValueError
Returns:
datasets.DatasetDict: A dict of relevance of query and document.
"""
if self.dataset_dir is not None:
if dataset_name is None:
save_dir = self.dataset_dir
@@ -91,6 +141,18 @@ def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') ->
return self._load_remote_qrels(dataset_name=dataset_name, split=split)

def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Load the queries from the dataset.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): The split to load queries from. Defaults to 'test'.
Raises:
ValueError
Returns:
datasets.DatasetDict: A dict of queries with id as key, query text as value.
"""
if self.dataset_dir is not None:
if dataset_name is None:
save_dir = self.dataset_dir
@@ -111,6 +173,18 @@ def _load_remote_corpus(
dataset_name: Optional[str] = None,
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Abstract method to load corpus from remote dataset, to be overrode in child class.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to None.
Raises:
NotImplementedError: Loading remote corpus is not implemented.
Returns:
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
"""
raise NotImplementedError("Loading remote corpus is not implemented.")

def _load_remote_qrels(
@@ -119,6 +193,19 @@ def _load_remote_qrels(
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Abstract method to load relevance from remote dataset, to be overrode in child class.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to None.
Raises:
NotImplementedError: Loading remote qrels is not implemented.
Returns:
datasets.DatasetDict: A dict of relevance of query and document.
"""
raise NotImplementedError("Loading remote qrels is not implemented.")

def _load_remote_queries(
@@ -127,9 +214,31 @@ def _load_remote_queries(
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Abstract method to load queries from remote dataset, to be overrode in child class.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to None.
Raises:
NotImplementedError
Returns:
datasets.DatasetDict: A dict of queries with id as key, query text as value.
"""
raise NotImplementedError("Loading remote queries is not implemented.")

def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
"""Load corpus from local dataset.
Args:
save_dir (str): Path to save the loaded corpus.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
Returns:
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
"""
corpus_path = os.path.join(save_dir, 'corpus.jsonl')
if self.force_redownload or not os.path.exists(corpus_path):
logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.")
@@ -144,6 +253,19 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None)
return datasets.DatasetDict(corpus)

def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Load relevance from local dataset.
Args:
save_dir (str): Path to save the loaded relevance.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): Split to load from the local dataset. Defaults to 'test'.
Raises:
ValueError
Returns:
datasets.DatasetDict: A dict of relevance of query and document.
"""
checked_split = self.check_splits(split)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
@@ -166,6 +288,19 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
return datasets.DatasetDict(qrels)

def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Load queries from local dataset.
Args:
save_dir (str): Path to save the loaded queries.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
split (str, optional): Split to load from the local dataset. Defaults to 'test'.
Raises:
ValueError
Returns:
datasets.DatasetDict: A dict of queries with id as key, query text as value.
"""
checked_split = self.check_splits(split)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
@@ -182,6 +317,18 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
return datasets.DatasetDict(queries)

def _download_file(self, download_url: str, save_dir: str):
"""Download file from provided URL.
Args:
download_url (str): Source URL of the file.
save_dir (str): Path to the directory to save the zip file.
Raises:
FileNotFoundError
Returns:
str: The path of the downloaded file.
"""
save_path = os.path.join(save_dir, download_url.split('/')[-1])

if self.force_redownload or (not os.path.exists(save_path) or os.path.getsize(save_path) == 0):
@@ -201,6 +348,14 @@ def _download_file(self, download_url: str, save_dir: str):
return save_path

def _get_fpath_size(self, fpath: str) -> int:
"""Get the total size of the files in provided path.
Args:
fpath (str): path of files to compute the size.
Returns:
int: The total size in bytes.
"""
if not os.path.isdir(fpath):
return os.path.getsize(fpath)
else:
@@ -212,6 +367,18 @@ def _get_fpath_size(self, fpath: str) -> int:
return total_size

def _download_gz_file(self, download_url: str, save_dir: str):
"""Download and unzip the gzip file from provided URL.
Args:
download_url (str): Source URL of the gzip file.
save_dir (str): Path to the directory to save the gzip file.
Raises:
FileNotFoundError: _description_
Returns:
str: The path to the file after unzip.
"""
gz_file_path = self._download_file(download_url, save_dir)
cmd = ["gzip", "-d", gz_file_path]
try:
@@ -226,6 +393,18 @@ def _download_gz_file(self, download_url: str, save_dir: str):
return file_path

def _download_zip_file(self, download_url: str, save_dir: str):
"""Download and unzip the zip file from provided URL.
Args:
download_url (str): Source URL of the zip file.
save_dir (str): Path to the directory to save the zip file.
Raises:
FileNotFoundError
Returns:
str: The path to the file after unzip.
"""
zip_file_path = self._download_file(download_url, save_dir)
file_path = zip_file_path.replace(".zip", "")
if self.force_redownload or not os.path.exists(file_path):
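Taken together, these docstrings describe the intended workflow: a concrete subclass supplies the dataset inventory and the remote-loading hooks, while the base class handles name/split validation and local caching. A minimal sketch of a hypothetical subclass follows; ToyEvalDataLoader and its in-memory data are illustrative assumptions, not part of this PR:

import datasets
from FlagEmbedding.abc.evaluation.data_loader import AbsEvalDataLoader

class ToyEvalDataLoader(AbsEvalDataLoader):
    def available_dataset_names(self):
        return ["toy"]

    def available_splits(self, dataset_name=None):
        return ["dev", "test"]

    def _load_remote_corpus(self, dataset_name=None, save_dir=None):
        # A real implementation would download corpus.jsonl (e.g. via
        # self._download_gz_file) and save it under save_dir.
        return datasets.DatasetDict({"0": {"title": "Toy", "text": "hello world"}})

loader = ToyEvalDataLoader(eval_name="toy_eval")   # dataset_dir defaults to None
loader.check_dataset_names("toy")                  # -> ["toy"]
loader.check_splits("test", dataset_name="toy")    # -> ["test"]
corpus = loader.load_corpus(dataset_name="toy")    # no dataset_dir, so the remote hook is used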

