diff --git a/benchmark/datasets.py b/benchmark/datasets.py
index 39223ff3..91a07ba8 100644
--- a/benchmark/datasets.py
+++ b/benchmark/datasets.py
@@ -619,6 +619,73 @@ def get_dataset(self):
     def distance(self):
         return "ip"
 
+class OpenAIArXivDataset(DatasetCompetitionFormat):
+    """Dataset of 1536-dimensional float32 vectors with 20000 queries.
+
+    Files live under ``BASEDIR/OpenAIArXiv``; ``base_url`` points at the
+    ``arxiv-openaiv2-2M`` storage location. A ground-truth file name is set
+    only for the full 2321096-vector set and the 100000-vector crop; for
+    any other ``nb`` it is ``None``.
+    """
+
+    # Size of the full dataset; any other nb gets a crop suffix on its file.
+    _FULL_NB = 2321096
+
+    def __init__(self, nb=_FULL_NB):
+        """Set up the metadata for a dataset of ``nb`` base vectors."""
+        self.nb = nb
+        self.d = 1536
+        self.nq = 20000
+        self.dtype = "float32"
+        self.ds_fn = "openai_base.bin"
+        self.qs_fn = "openai_query.bin"
+        # Only the full set and the 100K crop have a ground-truth name.
+        if nb == self._FULL_NB:
+            self.gt_fn = "openai-2M"
+        elif nb == 100000:
+            self.gt_fn = "openai-100K"
+        else:
+            self.gt_fn = None
+        self.basedir = os.path.join(BASEDIR, "OpenAIArXiv")
+        self.base_url = "https://comp21storage.z5.web.core.windows.net/arxiv-openaiv2-2M"
+
+        self.private_qs_url = None
+        self.private_gt_url = None
+
+    def prepare(self, skip_data=False, original_size=_FULL_NB):
+        """Run the parent ``prepare`` with the full dataset size.
+
+        ``original_size`` is ignored: the full size (2321096) is always
+        forwarded, whatever the caller passes.
+        """
+        return super().prepare(skip_data, self._FULL_NB)
+
+    def get_dataset_fn(self):
+        """Return the path of the base-vector file for this ``nb``.
+
+        For any ``nb`` other than the full size, the file name carries a
+        ``.crop_nb_<nb>`` suffix.
+
+        Raises:
+            RuntimeError: if the file does not exist.
+        """
+        fn = os.path.join(self.basedir, self.ds_fn)
+        if self.nb != self._FULL_NB:
+            fn += '.crop_nb_%d' % self.nb
+        if not os.path.exists(fn):
+            raise RuntimeError("file %s not found" %fn)
+        return fn
+
+    def get_dataset(self):
+        """Return the first ``nb``-sized batch, passed through ``sanitize``."""
+        # Ask the iterator for batches of nb vectors and keep the first one.
+        vectors = next(self.get_dataset_iterator(bs=self.nb))
+        return sanitize(vectors)
+
+    def distance(self):
+        """Return the name of the distance metric used by this dataset."""
+        return "euclidean"
+
 class RandomClusteredDS(DatasetCompetitionFormat):
     def __init__(self, basedir="random-clustered"):
         self.nb = 10000
@@ -1264,6 +1331,9 @@ def short_name(self):
     'msmarco-10M': lambda : MSMarcoWebSearchDataset(10000000),
     'msmarco-1M': lambda : MSMarcoWebSearchDataset(1000000),
 
+    'openai-2M': lambda : OpenAIArXivDataset(2321096),
+    'openai-100K': lambda : OpenAIArXivDataset(100000),
+
     'random-xs': lambda : RandomDS(10000, 1000, 20),
     'random-s': lambda : RandomDS(100000, 1000, 50),
 