Skip to content

Commit

Permalink
feat: fix asset download
Browse files Browse the repository at this point in the history
  • Loading branch information
theodu committed Aug 28, 2023
1 parent 89e8894 commit c1dd6e4
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 24 deletions.
9 changes: 8 additions & 1 deletion src/kili/gateways/kili_api_gateway/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from kili.core.graphql.graphql_client import GraphQLClient
from kili.exceptions import NotFound
from kili.gateways.kili_api_gateway.project.operations import get_project_query
from kili.gateways.kili_api_gateway.project.types import ProjectWhere
from kili.gateways.kili_api_gateway.queries import fragment_builder
Expand All @@ -26,4 +27,10 @@ def get_project(
result = self.graphql_client.execute(
query=query, variables={"where": where.build_gql_value()}
)
return result["data"]
projects = result["data"]
if len(projects) == 0:
raise NotFound(
f"project ID: {project_id}. Maybe your KILI_API_KEY does not belong to a member of"
" the project."
)
return projects[0]
2 changes: 1 addition & 1 deletion src/kili/services/copy_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def download_and_upload_assets(assets):

def _download_assets(self, from_project_id, fields, tmp_dir, assets):
download_function, _ = get_download_assets_function(
self.kili,
self.kili.kili_api_gateway,
download_media=True,
fields=fields,
project_id=from_project_id,
Expand Down
2 changes: 1 addition & 1 deletion src/kili/services/export/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def fetch_assets(

options = QueryOptions(disable_tqdm=disable_tqdm)
post_call_function, fields = get_download_assets_function(
kili, download_media, fields, project_id, local_media_dir
kili.kili_api_gateway, download_media, fields, project_id, local_media_dir
)
assets = list(
AssetQuery(kili.graphql_client, kili.http_client)(
Expand Down
2 changes: 1 addition & 1 deletion src/kili/use_cases/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def list_assets(
validate_category_search_query(where.label_category_search)

post_call_function, fields = get_download_assets_function(
self, download_media, fields, where.project_id, local_media_dir
self._kili_api_gateway, download_media, fields, where.project_id, local_media_dir
)
assets_gen = self._kili_api_gateway.list_assets(fields, where, options, post_call_function)

Expand Down
9 changes: 5 additions & 4 deletions src/kili/use_cases/asset/asset_label_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

def parse_labels_of_asset(asset: Dict, project: LabelParsingProject) -> Dict:
"""Parse the labels of an asset queried to Kili."""
if asset.get("labels", {}).get("jsonResponse") is not None:
asset["labels"] = parse_labels(
asset["labels"], project["jsonInterface"], project["inputType"]
)
if asset.get("labels"):
if asset["labels"][0].get("jsonResponse") is not None:
asset["labels"] = parse_labels(
asset["labels"], project["jsonInterface"], project["inputType"]
)
if asset.get("latestLabel", {}).get("jsonResponse") is not None:
asset["latestLabel"] = ParsedLabel(
label=asset["latestLabel"],
Expand Down
26 changes: 10 additions & 16 deletions src/kili/use_cases/asset/media_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
DataConnectionsQuery,
DataConnectionsWhere,
)
from kili.core.graphql.operations.project.queries import ProjectQuery, ProjectWhere
from kili.exceptions import NotFound
from kili.gateways.kili_api_gateway import KiliAPIGateway

from ...entrypoints.queries.asset.exceptions import (
DownloadNotAllowedError,
Expand All @@ -27,7 +26,7 @@


def get_download_assets_function(
kili,
kili_api_gateway: KiliAPIGateway,
download_media: bool,
fields: List[str],
project_id: str,
Expand All @@ -40,23 +39,19 @@ def get_download_assets_function(
Also returns the fields to be queried, which may be modified
if the jsonContent field is necessary.
"""
print("get_download_assets_function")
if not download_media:
return None, fields

projects = list(
ProjectQuery(kili.graphql_client, kili.http_client)(
ProjectWhere(project_id=project_id), ["inputType"], QueryOptions(disable_tqdm=True)
)
)
if len(projects) == 0:
raise NotFound(
f"project ID: {project_id}. Maybe your KILI_API_KEY does not belong to a member of the"
" project."
)
project = kili_api_gateway.get_project(project_id=project_id, fields=["inputType"])
input_type = project["inputType"]
print(project, download_media, local_media_dir)

# We need to query the data connections to know if the assets are hosted in a cloud storage
# If so, we remove the fields "content" and "jsonContent" from the query
data_connections_gen = DataConnectionsQuery(kili.graphql_client, kili.http_client)(
data_connections_gen = DataConnectionsQuery(
kili_api_gateway.graphql_client, kili_api_gateway.http_client
)(
where=DataConnectionsWhere(project_id=project_id),
fields=["id"],
options=QueryOptions(disable_tqdm=True, first=1, skip=0),
Expand All @@ -67,7 +62,6 @@ def get_download_assets_function(
" Asset download is disabled."
)

input_type = projects[0]["inputType"]
jsoncontent_field_added = False
if input_type in ("TEXT", "VIDEO") and "jsonContent" not in fields:
fields = fields + ["jsonContent"]
Expand All @@ -79,7 +73,7 @@ def get_download_assets_function(
project_id,
jsoncontent_field_added,
input_type,
kili.http_client,
kili_api_gateway.http_client,
).download_assets,
fields,
)
Expand Down
125 changes: 125 additions & 0 deletions tests/integration/use_cases/test_asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from pathlib import Path
from tempfile import TemporaryDirectory

from kili.gateways.kili_api_gateway import KiliAPIGateway
from kili.gateways.kili_api_gateway.asset.types import AssetWhere
from kili.gateways.kili_api_gateway.queries import QueryOptions
from kili.use_cases.asset import AssetUseCases
from kili.utils.labels.parsing import ParsedLabel


def test_given_query_parameters_I_can_query_assets(kili_api_gateway: KiliAPIGateway):
    """Listing assets without media download yields exactly what the gateway returns."""
    # Arrange: the mocked gateway serves 200 references to the same minimal asset.
    expected_assets = [{"id": "asset_id"}] * 200
    kili_api_gateway.list_assets.return_value = expected_assets
    use_cases = AssetUseCases(kili_api_gateway)

    # Act: query with plain "dict" labels and no media download.
    returned_assets = list(
        use_cases.list_assets(
            AssetWhere(project_id="project_id"),
            ["id"],
            QueryOptions(disable_tqdm=False),
            download_media=False,
            local_media_dir=None,
            label_output_format="dict",
        )
    )

    # Assert
    assert returned_assets == expected_assets


def test_given_query_parameters_I_can_query_assets_and_get_their_labels_parsed(
    kili_api_gateway: KiliAPIGateway,
):
    """Check that ``label_output_format="parsed_label"`` converts labels to ``ParsedLabel``.

    The mocked gateway returns a single asset carrying both a ``labels`` list and a
    ``latestLabel``, plus a project whose JSON interface declares the classification
    job referenced by the labels. After listing, both label entries must be
    ``ParsedLabel`` instances.
    """
    # mocking
    # NOTE: "CATGORY_A" is spelled the same way in the label and in the JSON interface;
    # the two must stay in sync, so the spelling is left untouched.
    json_response = {"JOB_0": {"categories": [{"name": "CATGORY_A"}]}}
    asset = {
        "id": "asset_id",
        "labels": [{"jsonResponse": json_response}],
        "latestLabel": {"jsonResponse": json_response},
    }
    json_interface = {
        "jobs": {
            "JOB_0": {
                "mlTask": "CLASSIFICATION",
                "isChild": False,
                "content": {
                    "categories": {
                        "CATGORY_A": {"children": [], "name": "category A", "id": "category30"},
                    },
                    "input": "checkbox",
                },
            }
        }
    }
    # A generator mimics the lazy iterable handed back by the gateway.
    kili_api_gateway.list_assets.return_value = (asset for asset in [asset])
    kili_api_gateway.get_project.return_value = {
        "jsonInterface": json_interface,
        "inputType": "TEXT",
    }

    # given parameters to query assets
    asset_use_cases = AssetUseCases(kili_api_gateway)
    where = AssetWhere(project_id="project_id")
    # Field names mirror the keys of the mocked asset ("labels", "jsonResponse").
    fields = ["id", "labels.jsonResponse", "latestLabel.jsonResponse"]
    options = QueryOptions(disable_tqdm=False)

    # when creating query assets
    asset_gen = asset_use_cases.list_assets(
        where,
        fields,
        options,
        download_media=False,
        local_media_dir=None,
        label_output_format="parsed_label",
    )
    returned_asset = list(asset_gen)[0]

    # then
    assert returned_asset == asset
    assert isinstance(returned_asset["latestLabel"], ParsedLabel)
    assert isinstance(returned_asset["labels"][0], ParsedLabel)


def test_given_query_parameters_I_can_query_assets_and_download_their_media(
    kili_api_gateway: KiliAPIGateway,
):
    """Listing assets with ``download_media=True`` stores the media in ``local_media_dir``.

    The returned asset's ``content`` must point at the downloaded file, which is
    expected to be named after the asset's ``externalId``.
    """
    # Arrange: the mocked gateway lazily yields one remote image asset.
    remote_asset = {
        "id": "asset_id",
        "content": "http://test.jpg",
        "externalId": "external_id.jpg",
    }
    kili_api_gateway.list_assets.return_value = (item for item in [remote_asset])
    use_cases = AssetUseCases(kili_api_gateway)
    asset_filter = AssetWhere(project_id="project_id")
    requested_fields = ["id", "content"]
    query_options = QueryOptions(disable_tqdm=False)

    # Everything that touches the download directory stays inside the context
    # manager: the directory is deleted as soon as the block exits.
    with TemporaryDirectory() as tmp_dir:
        # Act
        listed_assets = list(
            use_cases.list_assets(
                asset_filter,
                requested_fields,
                query_options,
                download_media=True,
                local_media_dir=tmp_dir,
                label_output_format="dict",
            )
        )
        returned_asset = listed_assets[0]

        # Assert
        expected_file_path = Path(tmp_dir) / "external_id.jpg"
        expected_asset = {
            "id": "asset_id",
            "content": str(expected_file_path),
        }
        assert returned_asset == expected_asset
        assert expected_file_path.is_file()

0 comments on commit c1dd6e4

Please sign in to comment.