Skip to content

Commit

Permalink
Add nodes/{node_id}/download endpoint to REST API (#74)
Browse files Browse the repository at this point in the history
We implement the download of nodes that implement the export
functionality as in aiida-core. We use StreamingResponse to send the
data in byte chunks of size 1024.

The tests require the client to be asynchronous to prevent I/O
operations on closed file handles.

We require a numpy dependency because we need a node type that
implements the export functionality. The only such node type that is not
bound to the domain of materials science is ArrayData.
  • Loading branch information
agoscinski authored Nov 21, 2024
1 parent e06f267 commit 366f036
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 3 deletions.
3 changes: 3 additions & 0 deletions aiida_restapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
'disabled': False,
}
}

# The chunk size, in bytes, used when streaming data for download
DOWNLOAD_CHUNK_SIZE = 1024
50 changes: 48 additions & 2 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import os
import tempfile
from pathlib import Path
from typing import Any, List, Optional
from typing import Any, Generator, List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import EntryPointError
from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent
from aiida.plugins.entry_point import load_entry_point
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import ValidationError

from aiida_restapi import models, resources
from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

from .auth import get_current_active_user

Expand Down Expand Up @@ -41,6 +43,50 @@ async def get_nodes_download_formats() -> dict[str, Any]:
return resources.get_all_download_formats()


@router.get('/nodes/{nodes_id}/download')
@with_dbenv()
async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse:
    """Stream the exported content of a node in the requested download format.

    :param nodes_id: identifier of the node, passed to ``load_node``.
    :param download_format: name of the export format; it has to be one of the formats returned by
        ``node.get_export_formats()``.
    :returns: a ``StreamingResponse`` that yields the exported bytes in chunks of ``DOWNLOAD_CHUNK_SIZE``.
    :raises HTTPException: with status 404 if the node cannot be loaded, 422 if the download format is
        missing or not supported by the node, and 500 if the export raises a ``LicensingException``.
    """
    import io

    from aiida.orm import load_node

    try:
        node = load_node(nodes_id)
    except NotExistent as exc:
        raise HTTPException(status_code=404, detail=f'Could not find any node with id {nodes_id}') from exc

    if download_format is None:
        raise HTTPException(
            status_code=422,
            detail='Please specify the download format. '
            'The available download formats can be '
            'queried using the /nodes/download_formats/ endpoint.',
        )

    if download_format not in node.get_export_formats():
        raise HTTPException(
            status_code=422,
            detail=f'The format {download_format} is not supported. '
            'The available download formats can be '
            'queried using the /nodes/download_formats/ endpoint.',
        )

    try:
        # _exportcontent returns (byteobj, dict with {filename: filecontent}); only the bytes are streamed
        exported_bytes, _ = node._exportcontent(download_format)
    except LicensingException as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    def stream() -> Generator[bytes, None, None]:
        """Yield the exported bytes in chunks of at most ``DOWNLOAD_CHUNK_SIZE``."""
        with io.BytesIO(exported_bytes) as handler:
            while chunk := handler.read(DOWNLOAD_CHUNK_SIZE):
                yield chunk

    # NOTE(review): 'application/<format>' is not necessarily a registered media type for every
    # export format - confirm whether a mapping to proper MIME types is wanted.
    return StreamingResponse(stream(), media_type=f'application/{download_format}')


@router.get('/nodes/{nodes_id}', response_model=models.Node)
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
Expand Down
4 changes: 4 additions & 0 deletions docs/source/user_guide/graphql.md
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,10 @@ http://localhost:5000/api/v4/nodes/ffe11/repo/list
```html
http://localhost:5000/api/v4/nodes/ffe11/repo/contents?filename="aiida.in"
```


Not implemented for GraphQL; please use the REST API for this use case.

```html
http://localhost:5000/api/v4/nodes/fafdsf/download?download_format=xsf
```
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ testing = [
'pytest-regressions',
'pytest-cov',
'requests',
'httpx'
'httpx',
'numpy~=1.21'
]

[project.urls]
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union

import numpy as np
import pytest
import pytz
from aiida import orm
Expand Down Expand Up @@ -164,6 +165,17 @@ def default_nodes():
return [node_1.pk, node_2.pk, node_3.pk, node_4.pk]


@pytest.fixture(scope='function')
def array_data_node():
    """Populate the database with a downloadable node (implementing a ``_prepare_*`` function).

    To test the chunking of the streamed download, the array is made large enough that its exported
    content has to be split into at least two chunks.
    """
    from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

    # Each element takes at least one byte in a textual export such as the JSON used by the tests,
    # so one element more than the chunk size guarantees a payload larger than a single chunk.
    # (The previous `DOWNLOAD_CHUNK_SIZE // 64 + 1` produced far fewer bytes than one chunk.)
    nb_elements = DOWNLOAD_CHUNK_SIZE + 1
    return orm.ArrayData(np.arange(nb_elements, dtype=np.int64)).store()


@pytest.fixture(scope='function')
def authenticate():
"""Authenticate user.
Expand Down
21 changes: 21 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,24 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused
assert check_response.status_code == 200, response.content
assert check_response.json()['extras']['extra_one'] == 'value_1'
assert check_response.json()['extras']['extra_two'] == 'value_2'


@pytest.mark.anyio
async def test_get_download_node(array_data_node, async_client):
    """Exercise the /nodes/{nodes_id}/download endpoint.

    The async client is required; otherwise an I/O operation on a closed file causes an error."""
    endpoint = f'/nodes/{array_data_node.pk}/download'

    # A supported format: the array content comes back as JSON
    json_response = await async_client.get(f'{endpoint}?download_format=json')
    assert json_response.status_code == 200, json_response.json()
    downloaded_values = json_response.json().get('default')
    assert downloaded_values == array_data_node.get_array().tolist()

    # An unsupported format: the request is rejected
    unsupported_response = await async_client.get(f'{endpoint}?download_format=cif')
    assert unsupported_response.status_code == 422, unsupported_response.json()
    unsupported_detail = unsupported_response.json()['detail']
    assert 'format cif is not supported' in unsupported_detail

    # No format at all: the request is rejected
    missing_response = await async_client.get(endpoint)
    assert missing_response.status_code == 422, missing_response.json()
    missing_detail = missing_response.json()['detail']
    assert 'Please specify the download format' in missing_detail

0 comments on commit 366f036

Please sign in to comment.