cleanup: arango_datasets
aMahanna committed Dec 27, 2023
1 parent 33f7df2 commit 8191545
Showing 2 changed files with 177 additions and 140 deletions.
1 change: 1 addition & 0 deletions arango_datasets/__init__.py
@@ -0,0 +1 @@
from arango_datasets.datasets import Datasets # noqa: F401
316 changes: 176 additions & 140 deletions arango_datasets/datasets.py
@@ -1,12 +1,9 @@
import json
import sys
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, Optional

import requests
from arango.collection import StandardCollection
from arango.database import Database
from arango.exceptions import CollectionCreateError, DocumentInsertError
from requests import ConnectionError, HTTPError

from .utils import progress

@@ -16,12 +13,14 @@ class Datasets:
:param db: A python-arango database instance
:type db: arango.database.Database
:param batch_size:
Optional batch size supplied to python-arango import_bulk function
:param batch_size: Optional batch size supplied to the
python-arango `import_bulk` function. Defaults to 50.
:type batch_size: int
:param metadata_file: Optional URL for datasets metadata file
:param metadata_file: URL for datasets metadata file. Defaults to
"https://arangodb-dataset-library-ml.s3.amazonaws.com/root_metadata.json".
:type metadata_file: str
:param preserve_existing: Boolean to preserve existing data and graph definiton
:param preserve_existing: Whether to preserve the existing collections and
graph of the dataset (if any). Defaults to False.
:type preserve_existing: bool
"""

@@ -32,150 +31,187 @@ def __init__(
metadata_file: str = "https://arangodb-dataset-library-ml.s3.amazonaws.com/root_metadata.json", # noqa: E501
preserve_existing: bool = False,
):
self.metadata_file: str = metadata_file
self.metadata_contents: Dict[str, Any]
self.batch_size = batch_size
self.user_db = db
self.preserve_existing = preserve_existing
self.file_type: str
if issubclass(type(db), Database) is False:
if not isinstance(db, Database):
msg = "**db** parameter must inherit from arango.database.Database"
raise TypeError(msg)

try:
response = requests.get(self.metadata_file, timeout=6000)
response.raise_for_status()
self.metadata_contents = response.json()
except (HTTPError, ConnectionError) as e:
print("Unable to retrieve metadata information.")
print(e)
raise
self.user_db = db
self.batch_size = batch_size
self.metadata_file = metadata_file
self.preserve_existing = preserve_existing

self.labels = []
for label in self.metadata_contents:
self.labels.append(label)
self.__metadata: Dict[str, Dict[str, Any]]
self.__metadata = self.__get_response(self.metadata_file).json()
self.__dataset_names = [n for n in self.__metadata]

def list_datasets(self) -> List[str]:
print(self.labels)
return self.labels
"""List available datasets
def dataset_info(self, dataset_name: str) -> Dict[str, Any]:
for i in self.metadata_contents[str(dataset_name).upper()]:
print(f"{i}: {self.metadata_contents[str(dataset_name).upper()][i]} ")
print("")
return self.metadata_contents
:return: Names of the available datasets to load.
:rtype: List[str]
"""
print(self.__dataset_names)
return self.__dataset_names

def insert_docs(
def dataset_info(self, dataset_name: str) -> Dict[str, Any]:
"""Get information about a dataset
:param dataset_name: Name of the dataset.
:type dataset_name: str
:return: Some metadata about the dataset.
:rtype: Dict[str, Any]
:raises ValueError: If the dataset is not found.
"""
if dataset_name.upper() not in self.__dataset_names:
raise ValueError(f"Dataset '{dataset_name}' not found")

info: Dict[str, Any] = self.__metadata[dataset_name.upper()]
print(info)
return info

def load(
self,
collection: StandardCollection,
docs: List[Dict[Any, Any]],
collection_name: str,
dataset_name: str,
batch_size: Optional[int] = None,
preserve_existing: Optional[bool] = None,
) -> None:
try:
with progress(f"Collection: {collection_name}") as p:
p.add_task("insert_docs")
"""Load a dataset into the database.
:param dataset_name: Name of the dataset.
:type dataset_name: str
:param batch_size: Batch size supplied to the
python-arango `import_bulk` function. Overrides the **batch_size**
supplied to the constructor. Defaults to None.
:type batch_size: Optional[int]
:param preserve_existing: Whether to preserve the existing collections and
graph of the dataset (if any). Overrides the **preserve_existing**
supplied to the constructor. Defaults to False.
:type preserve_existing: bool
:raises ValueError: If the dataset is not found.
"""
if dataset_name.upper() not in self.__dataset_names:
raise ValueError(f"Dataset '{dataset_name}' not found")

dataset_contents = self.__metadata[dataset_name.upper()]

# Backwards compatibility
self.batch_size = batch_size if batch_size is not None else self.batch_size
self.preserve_existing = (
preserve_existing
if preserve_existing is not None
else self.preserve_existing
)

file_type = dataset_contents["file_type"]
load_function: Callable[[str], Any]
if file_type == "json":
load_function = self.__load_json
elif file_type == "jsonl":
load_function = self.__load_jsonl
else:
raise ValueError(f"Unsupported file type: {file_type}")

for data, is_edge in [
(dataset_contents["edges"], True),
(dataset_contents["vertices"], False),
]:
for col_data in data:
col = self.__initialize_collection(col_data["collection_name"], is_edge)

for file in col_data["files"]:
self.__import_bulk(col, load_function(file))

if edge_definitions := dataset_contents.get("edge_definitions"):
graph_name = dataset_contents["graph_name"]
self.user_db.delete_graph(graph_name, ignore_missing=True)
self.user_db.create_graph(graph_name, edge_definitions)

def __get_response(self, url: str, timeout: int = 60) -> requests.Response:
"""Wrapper around requests.get() with a progress bar.
:param url: URL to get a response from.
:type url: str
:param timeout: Timeout in seconds. Defaults to 60.
:type timeout: int
:raises ConnectionError: If the connection fails.
:raises HTTPError: If the HTTP request fails.
:return: The response from the URL.
:rtype: requests.Response
"""
with progress(f"GET: {url}") as p:
p.add_task("get_response")

response = requests.get(url, timeout=timeout)
response.raise_for_status()
return response

def __initialize_collection(
self, collection_name: str, is_edge: bool
) -> StandardCollection:
"""Initialize a collection.
:param collection_name: Name of the collection.
:type collection_name: str
:param is_edge: Whether the collection is an edge collection.
:type is_edge: bool
:raises CollectionCreateError: If the collection cannot be created.
:return: The collection.
:rtype: arango.collection.StandardCollection
"""
if self.preserve_existing is False:
m = f"Collection '{collection_name}' already exists, dropping and creating with new data." # noqa: E501
print(m)

self.user_db.delete_collection(collection_name, ignore_missing=True)

return self.user_db.create_collection(collection_name, edge=is_edge)

def __load_json(self, file_url: str) -> List[Dict[str, Any]]:
"""Load a JSON file into memory.
:param file_url: URL of the JSON file.
:type file_url: str
:raises ConnectionError: If the connection fails.
:raises HTTPError: If the HTTP request fails.
:return: The JSON data.
:rtype: Dict[str, Any]
"""
json_data: List[Dict[str, Any]] = self.__get_response(file_url).json()
return json_data

def __load_jsonl(self, file_url: str) -> List[Dict[str, Any]]:
"""Load a JSONL file into memory.
:param file_url: URL of the JSONL file.
:type file_url: str
:raises ConnectionError: If the connection fails.
:raises HTTPError: If the HTTP request fails.
:return: The JSONL data as a list of dictionaries.
"""
json_data = []
data = self.__get_response(file_url)

collection.import_bulk(docs, batch_size=self.batch_size)
if data.encoding is None:
data.encoding = "utf-8"

except DocumentInsertError as exec:
print("Document insertion failed due to the following error:")
print(exec.message)
sys.exit(1)
for line in data.iter_lines(decode_unicode=True):
if line:
json_data.append(json.loads(line))

print(f"Finished loading current file for collection: {collection_name}")
return json_data

def load_json(
self,
collection_name: str,
edge_type: bool,
file_url: str,
collection: StandardCollection,
) -> None:
try:
with progress(f"Downloading file for: {collection_name}") as p:
p.add_task("load_file")
data = requests.get(file_url, timeout=6000).json()
except (HTTPError, ConnectionError) as e:
print("Unable to download file.")
print(e)
raise e
print(f"Downloaded file for: {collection_name}, now importing... ")
self.insert_docs(collection, data, collection_name)

def load_jsonl(
self,
collection_name: str,
edge_type: bool,
file_url: str,
collection: StandardCollection,
def __import_bulk(
self, collection: StandardCollection, docs: List[Dict[str, Any]]
) -> None:
json_data = []
try:
with progress(f"Downloading file for: {collection_name}") as p:
p.add_task("load_file")
data = requests.get(file_url, timeout=6000)

if data.encoding is None:
data.encoding = "utf-8"

for line in data.iter_lines(decode_unicode=True):
if line:
json_data.append(json.loads(line))

except (HTTPError, ConnectionError) as e:
print("Unable to download file.")
print(e)
raise
print(f"Downloaded file for: {collection_name}, now importing... ")
self.insert_docs(collection, json_data, collection_name)

def load_file(self, collection_name: str, edge_type: bool, file_url: str) -> None:
collection: StandardCollection
try:
collection = self.user_db.create_collection(collection_name, edge=edge_type)
except CollectionCreateError as exec:
print(
f"""Failed to create {collection_name} collection due
to the following error:"""
)
print(exec.error_message)
sys.exit(1)
if self.file_type == "json":
self.load_json(collection_name, edge_type, file_url, collection)
elif self.file_type == "jsonl":
self.load_jsonl(collection_name, edge_type, file_url, collection)
else:
raise ValueError(f"Unsupported file type: {self.file_type}")

def cleanup_collections(self, collection_name: str) -> None:
if (
self.user_db.has_collection(collection_name)
and self.preserve_existing is False
):
print(
f"""
Old collection found
${collection_name},
dropping and creating with new data."""
)
self.user_db.delete_collection(collection_name)

def load(self, dataset_name: str) -> None:
if str(dataset_name).upper() in self.labels:
self.file_type = self.metadata_contents[str(dataset_name).upper()][
"file_type"
]

for edge in self.metadata_contents[str(dataset_name).upper()]["edges"]:
self.cleanup_collections(collection_name=edge["collection_name"])
for e in edge["files"]:
self.load_file(edge["collection_name"], True, e)

for vertex in self.metadata_contents[str(dataset_name).upper()]["vertices"]:
self.cleanup_collections(collection_name=vertex["collection_name"])
for v in vertex["files"]:
self.load_file(vertex["collection_name"], False, v)

else:
print(f"Dataset `{str(dataset_name.upper())}` not found")
sys.exit(1)
"""Wrapper around python-arango's import_bulk() with a progress bar.
:param collection: The collection to insert the documents into.
:type collection: arango.collection.StandardCollection
:param docs: The documents to insert.
:type docs: List[Dict[Any, Any]]
:raises DocumentInsertError: If the document cannot be inserted.
"""
with progress(f"Collection: {collection.name}") as p:
p.add_task("insert_docs")

collection.import_bulk(docs, batch_size=self.batch_size)
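
For reference, the rewritten load() walks a per-dataset entry of the root metadata file using the keys read above (file_type, vertices, edges, collection_name, files, plus the optional edge_definitions / graph_name pair). A sketch of what such an entry might look like; the dataset name, collection names, and file URLs below are hypothetical placeholders:

# Hypothetical entry of root_metadata.json, shaped after the keys read by load().
example_entry = {
    "EXAMPLE_DATASET": {
        "file_type": "json",  # "json" and "jsonl" are the two supported types
        "vertices": [
            {"collection_name": "persons", "files": ["https://example.com/persons.json"]},
        ],
        "edges": [
            {"collection_name": "knows", "files": ["https://example.com/knows.json"]},
        ],
        # Optional: when present, load() drops and recreates this graph.
        "graph_name": "example_graph",
        "edge_definitions": [
            {
                "edge_collection": "knows",
                "from_vertex_collections": ["persons"],
                "to_vertex_collections": ["persons"],
            }
        ],
    }
}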

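A minimal usage sketch of the public API after this commit, assuming a local ArangoDB instance; the host, credentials, and dataset name below are hypothetical:

from arango import ArangoClient
from arango_datasets import Datasets

# Connect with python-arango (adjust host, database, and credentials as needed).
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="passwd"
)

datasets = Datasets(db)

print(datasets.list_datasets())  # names of the loadable datasets
datasets.dataset_info("IMDB")    # prints and returns metadata for one dataset
datasets.load("IMDB", batch_size=100, preserve_existing=False)

The batch_size and preserve_existing keyword arguments on load() override the values passed to the constructor, per the new signature above.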