Skip to content

Commit

Permalink
[ENH] partially validate local node JSON config (#43)
Browse files Browse the repository at this point in the history
* Add warning and catch for invalid JSON

* Add jsonschema and validation

* Enable partial node matching

* Refactored validation to do file level first

Co-authored-by: Alyssa Dai <[email protected]>

* added new test case

Co-authored-by: Alyssa Dai <[email protected]>

* Better error handling and more comments

Co-authored-by: Alyssa Dai <[email protected]>

---------

Co-authored-by: Alyssa Dai <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 14, 2023
1 parent 2d267b4 commit 1d76389
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 27 deletions.
117 changes: 93 additions & 24 deletions app/api/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,39 @@
from pathlib import Path

import httpx
import jsonschema
from fastapi import HTTPException
from jsonschema import validate

LOCAL_NODE_INDEX_PATH = Path(__file__).parents[2] / "local_nb_nodes.json"
FEDERATION_NODES = {}

# JSON Schema (draft-07) used to validate the local_nb_nodes.json file.
# The file may contain either a single node object or a non-empty array
# of node objects, so the top-level "oneOf" accepts both shapes.
LOCAL_NODE_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "definitions": {
        # One node entry: exactly these two string fields and nothing else.
        "node": {
            "type": "object",
            "properties": {
                # The URL has to start with an http:// or https:// scheme
                "ApiURL": {"type": "string", "pattern": "^(http|https)://"},
                "NodeName": {"type": "string"},
            },
            "required": ["ApiURL", "NodeName"],
            "additionalProperties": False,
        }
    },
    "oneOf": [
        # Array form: one or more node objects
        {
            "type": "array",
            "items": {"$ref": "#/definitions/node"},
            "minItems": 1,
        },
        # Object form: a single bare node
        {"$ref": "#/definitions/node"},
    ],
}


def add_trailing_slash(url: str) -> str:
"""Add trailing slash to a URL if it does not already have one."""
Expand All @@ -20,41 +48,76 @@ def add_trailing_slash(url: str) -> str:

def parse_nodes_as_dict(path: Path) -> dict:
    """
    Reads names and URLs of user-defined Neurobagel nodes from a JSON file
    (if available) and stores them in a dict
    where the keys are the node URLs, and the values are the node names.
    Makes sure node URLs end with a slash and only valid nodes are returned.

    The file may hold a single node object or an array of node objects.
    Nodes that do not fit the node schema are left out of the result and
    listed in a warning. A file that cannot be parsed as JSON also only
    produces a warning.

    Returns an empty dict when the file is missing, empty, not valid JSON,
    or contains no valid nodes.
    """
    if not path.exists() or path.stat().st_size == 0:
        return {}

    try:
        # JSON text is UTF-8, so do not depend on the platform's locale
        with open(path, "r", encoding="utf-8") as f:
            local_nodes = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # UnicodeDecodeError covers files that are not even readable text
        warnings.warn(f"You provided an invalid JSON file at {path}.")
        return {}

    # We wrap our input in a list if it isn't already to enable
    # easy iteration for adding trailing slashes, even though our
    # file level schema could handle a single non-array input
    input_nodes = (
        local_nodes if isinstance(local_nodes, list) else [local_nodes]
    )

    try:
        # We validate the entire file first, checking all nodes together
        validate(instance=input_nodes, schema=LOCAL_NODE_SCHEMA)
    except jsonschema.ValidationError:
        # At least one node is invalid (or the array is empty): fall back
        # to checking node by node so the valid ones can still be used
        node_schema = LOCAL_NODE_SCHEMA["definitions"]["node"]
        valid_nodes = []
        invalid_nodes = []
        for node in input_nodes:
            try:
                validate(instance=node, schema=node_schema)
            except jsonschema.ValidationError:
                invalid_nodes.append(node)
            else:
                valid_nodes.append(node)

        if invalid_nodes:
            warnings.warn(
                "Some of the nodes in the JSON are invalid:\n"
                f"{json.dumps(invalid_nodes, indent=2)}"
            )
    else:
        valid_nodes = input_nodes

    return {
        add_trailing_slash(node["ApiURL"]): node["NodeName"]
        for node in valid_nodes
    }


async def create_federation_node_index():
"""
Creates an index of nodes for federation, which is a dict where the keys are the node URLs, and the values are the node names.
Fetches the names and URLs of public Neurobagel nodes from a remote directory file, and combines them with the user-defined local nodes.
Creates an index of nodes for federation, which is a dict
where the keys are the node URLs, and the values are the node names.
Fetches the names and URLs of public Neurobagel nodes from a remote
directory file, and combines them with the user-defined local nodes.
"""
node_directory_url = "https://raw.githubusercontent.com/neurobagel/menu/main/node_directory/neurobagel_public_nodes.json"
local_nodes = parse_nodes_as_dict(LOCAL_NODE_INDEX_PATH)

if not local_nodes:
warnings.warn(
f"No local Neurobagel nodes defined or found. Federation will be limited to nodes available from the Neurobagel public node directory {node_directory_url}. "
"(To specify one or more local nodes to federate over, define them in a 'local_nb_nodes.json' file in the current directory and relaunch the API.)\n"
"No local Neurobagel nodes defined or found. Federation "
"will be limited to nodes available from the "
f"Neurobagel public node directory {node_directory_url}. "
"(To specify one or more local nodes to federate over, "
"define them in a 'local_nb_nodes.json' file in the "
"current directory and relaunch the API.)\n"
)

node_directory_response = httpx.get(
Expand Down Expand Up @@ -84,8 +147,10 @@ async def create_federation_node_index():
else:
warnings.warn(failed_get_warning)
raise RuntimeError(
"No local or public Neurobagel nodes available for federation. "
"Please define at least one local node in a 'local_nb_nodes.json' file in the current directory and try again."
"No local or public Neurobagel nodes available for federation. "
"Please define at least one local node in "
"a 'local_nb_nodes.json' file in the "
"current directory and try again."
)

# This step will remove any duplicate keys from the local and public node dicts, giving priority to the local nodes.
Expand Down Expand Up @@ -114,7 +179,10 @@ def check_nodes_are_recognized(node_urls: list):


def validate_query_node_url_list(node_urls: list) -> list:
"""Format and validate node URLs passed as values to the query endpoint, including setting a default list of node URLs when none are provided."""
"""
Format and validate node URLs passed as values to the query endpoint,
including setting a default list of node URLs when none are provided.
"""
# Remove and ignore node URLs that are empty strings
node_urls = list(filter(None, node_urls))
if node_urls:
Expand Down Expand Up @@ -155,7 +223,8 @@ def send_get_request(url: str, params: list):
params=params,
# TODO: Revisit timeout value when query performance is improved
timeout=30.0,
# Enable redirect following (off by default) so APIs behind a proxy can be reached
# Enable redirect following (off by default) so
# APIs behind a proxy can be reached
follow_redirects=True,
)

Expand Down
12 changes: 10 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
anyio==3.6.2
attrs==22.1.0
attrs==23.1.0
certifi==2023.7.22
cfgv==3.3.1
coverage==7.0.0
click==8.1.3
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.95.2
Expand All @@ -14,7 +14,10 @@ httpx==0.23.1
identify==2.5.9
idna==3.4
iniconfig==1.1.1
jsonschema==4.20.0
jsonschema-specifications==2023.11.2
nodeenv==1.7.0
numpy==1.26.2
orjson==3.8.6
packaging==21.3
pandas==1.5.2
Expand All @@ -24,8 +27,13 @@ pre-commit==2.20.0
pydantic==1.10.2
pyparsing==3.0.9
pytest==7.2.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0
referencing==0.31.1
rfc3986==1.5.0
rpds-py==0.13.2
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
toml==0.10.2
Expand Down
78 changes: 77 additions & 1 deletion tests/test_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
],
)
def test_add_trailing_slash(url, expected_url):
"""Test that a trailing slash is added to a URL if it does not already have one."""
"""
Test that a trailing slash is added to a
URL if it does not already have one.
"""
assert util.add_trailing_slash(url) == expected_url


Expand All @@ -28,6 +31,15 @@ def test_add_trailing_slash(url, expected_url):
},
{"http://firstnode.neurobagel.org/query/": "firstnode"},
),
(
[
{
"ApiURL": "http://firstnode.neurobagel.org/query",
"NodeName": "firstnode",
}
],
{"http://firstnode.neurobagel.org/query/": "firstnode"},
),
(
[
{
Expand Down Expand Up @@ -144,3 +156,67 @@ def test_validate_query_node_url_list(
)

assert util.validate_query_node_url_list(raw_url_list) == expected_url_list


@pytest.mark.parametrize(
    "set_nodes,expected_nodes",
    [
        # Single object with unknown keys: nothing usable
        (
            {
                "IMakeMyOwnRules": "http://firstnode.neurobagel.org/query",
                "WhatAreSchemas": "firstnode",
            },
            {},
        ),
        # Single object whose URL lacks an http(s) scheme
        (
            {
                "ApiURL": "this.is.not.a.url",
                "NodeName": "firstnode",
            },
            {},
        ),
        # Array with one valid and one invalid node: keep only the valid one
        (
            [
                {
                    "ApiURL": "https://firstnode.neurobagel.org/query/",
                    "NodeName": "firstnode",
                },
                {
                    "ApiURL": "invalidurl",
                    "NodeName": "secondnode",
                },
            ],
            {
                "https://firstnode.neurobagel.org/query/": "firstnode",
            },
        ),
    ],
)
def test_schema_invalid_nodes_raise_warning(
    set_nodes, expected_nodes, tmp_path
):
    """
    If the file is valid JSON but some of its nodes do not fit the node
    schema, expect a warning to be raised and only the nodes that fit
    the schema to be returned.
    """
    # TODO: split this test into the warning and the output
    # Write the parametrized config to a temporary file for the parser to read
    config_path = tmp_path / "local_nb_nodes.json"
    config_path.write_text(json.dumps(set_nodes, indent=2))

    with pytest.warns(
        UserWarning, match=r"Some of the nodes in the JSON are invalid.*"
    ):
        parsed_nodes = util.parse_nodes_as_dict(config_path)

    assert parsed_nodes == expected_nodes


def test_invalid_json_raises_warning(tmp_path):
    """
    Ensure that an invalid JSON file raises a warning but doesn't crash
    the app, and that no nodes are returned from it.
    """
    config_path = tmp_path / "local_nb_nodes.json"
    config_path.write_text("this is not valid JSON")

    with pytest.warns(UserWarning, match="You provided an invalid JSON"):
        nodes = util.parse_nodes_as_dict(config_path)

    # Not crashing is only half the contract: the caller must get an
    # empty node index back so federation falls back to public nodes
    assert nodes == {}

0 comments on commit 1d76389

Please sign in to comment.