Skip to content

Commit

Permalink
added openai arxiv dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
magdalendobson committed Aug 7, 2024
1 parent c206f33 commit a764cfe
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions benchmark/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,44 @@ def get_dataset(self):
def distance(self):
    """Return the identifier of the distance metric used by this dataset.

    Returns:
        The string "ip" (presumably inner product -- confirm against the
        consumers of this value).
    """
    return "ip"

class OpenAIArXivDataset(DatasetCompetitionFormat):
    """OpenAI arXiv dataset exposed through ``DatasetCompetitionFormat``.

    Vectors have 1536 float32 components and there are 20000 queries.
    Ground-truth file names are only known for the full base set
    (2321096 vectors) and the 100000-vector crop; for any other ``nb``
    the ``gt_fn`` attribute is ``None``.
    """

    # Number of base vectors in the full (uncropped) base file.
    _FULL_NB = 2321096

    # Ground-truth file name for each base-set size that has one.
    _GT_FILES = {2321096: "openai-2M", 100000: "openai-100K"}

    def __init__(self, nb=2321096):
        """Record sizes, file names and the download location.

        Args:
            nb: number of base vectors to use. Defaults to the full set.
        """
        self.nb = nb
        self.d = 1536
        self.nq = 20000
        self.dtype = "float32"
        self.ds_fn = "openai_base.bin"
        self.qs_fn = "openai_query.bin"
        # Unknown sizes map to None rather than raising.
        self.gt_fn = self._GT_FILES.get(nb)
        self.basedir = os.path.join(BASEDIR, "OpenAIArXiv")
        self.base_url = "https://comp21storage.z5.web.core.windows.net/arxiv-openaiv2-2M"

        self.private_qs_url = None
        self.private_gt_url = None

    def prepare(self, skip_data=False, original_size=2321096):
        """Delegate to the base-class ``prepare`` with the full dataset size.

        Args:
            skip_data: forwarded unchanged to the base class.
            original_size: accepted but ignored; the full size (2321096) is
                always passed on, whatever value the caller supplies.

        Returns:
            Whatever the base-class ``prepare`` returns.
        """
        return super().prepare(skip_data, self._FULL_NB)

    def get_dataset_fn(self):
        """Return the path of the base-vector file for the configured size.

        When ``nb`` is not the full size, the path carries a
        ``.crop_nb_<nb>`` suffix.

        Raises:
            RuntimeError: if the resulting path does not exist.
        """
        path = os.path.join(self.basedir, self.ds_fn)
        if self.nb != self._FULL_NB:
            path += '.crop_nb_%d' % self.nb
        if not os.path.exists(path):
            raise RuntimeError("file %s not found" % path)
        return path

    def get_dataset(self):
        """Return the first batch of the dataset iterator, sanitized.

        The iterator is asked for a batch size of ``nb``, and only its
        first item is consumed.
        """
        # Named first_batch rather than slice to avoid shadowing the builtin.
        first_batch = next(self.get_dataset_iterator(bs=self.nb))
        return sanitize(first_batch)

    def distance(self):
        """Return the identifier of the distance metric: "euclidean"."""
        return "euclidean"

class RandomClusteredDS(DatasetCompetitionFormat):
def __init__(self, basedir="random-clustered"):
self.nb = 10000
Expand Down Expand Up @@ -1264,6 +1302,9 @@ def short_name(self):
'msmarco-10M': lambda : MSMarcoWebSearchDataset(10000000),
'msmarco-1M': lambda : MSMarcoWebSearchDataset(1000000),

'openai-2M': lambda : OpenAIArXivDataset(2321096),
'openai-100K': lambda : OpenAIArXivDataset(100000),

'random-xs': lambda : RandomDS(10000, 1000, 20),
'random-s': lambda : RandomDS(100000, 1000, 50),

Expand Down

0 comments on commit a764cfe

Please sign in to comment.