Skip to content

Commit

Permalink
Merge pull request #180 from SubstraFoundation/improve-fetch-models
Browse files Browse the repository at this point in the history
Improve fetch models error handling and content stream
  • Loading branch information
Kelvin-M authored Mar 5, 2020
2 parents d52acd1 + edeb7cc commit f57b97e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
47 changes: 27 additions & 20 deletions backend/substrapp/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from os import path
import json
from multiprocessing.managers import BaseManager
from threading import Thread
import logging
import tarfile

Expand All @@ -27,7 +26,7 @@
from substrapp.ledger_utils import (log_start_tuple, log_success_tuple, log_fail_tuple,
query_tuples, LedgerError, LedgerStatusError, get_object_from_ledger)
from substrapp.tasks.utils import (ResourcesManager, compute_docker, get_asset_content, get_and_put_asset_content,
list_files, get_k8s_client, do_not_raise, timeit)
list_files, get_k8s_client, do_not_raise, timeit, ExceptionThread)
from substrapp.tasks.exception_handler import compute_error_code

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,24 +198,40 @@ def fetch_model(parent_tuple_type, authorized_types, input_model, directory):
raise TasksError(f'Traintuple: invalid input model: type={tuple_type}')


def prepare_traintuple_input_models(directory, tuple_):
"""Get traintuple input models content."""
input_models = tuple_.get('inModels')
if not input_models:
return

authorized_types = (AGGREGATETUPLE_TYPE, TRAINTUPLE_TYPE)
def fetch_models(tuple_type, authorized_types, input_models, directory):

models = []

for input_model in input_models:
proc = Thread(target=fetch_model,
args=(TRAINTUPLE_TYPE, authorized_types, input_model, directory))
proc = ExceptionThread(target=fetch_model,
args=(tuple_type, authorized_types, input_model, directory))
models.append(proc)
proc.start()

for proc in models:
proc.join()

exceptions = []

for proc in models:
if hasattr(proc, "_exception"):
exceptions.append(proc._exception)
logger.exception(proc._exception)
else:
if exceptions:
raise Exception(exceptions)


def prepare_traintuple_input_models(directory, tuple_):
"""Get traintuple input models content."""
input_models = tuple_.get('inModels')
if not input_models:
return

authorized_types = (AGGREGATETUPLE_TYPE, TRAINTUPLE_TYPE)

fetch_models(TRAINTUPLE_TYPE, authorized_types, input_models, directory)


def prepare_aggregatetuple_input_models(directory, tuple_):
"""Get aggregatetuple input models content."""
Expand All @@ -225,16 +240,8 @@ def prepare_aggregatetuple_input_models(directory, tuple_):
return

authorized_types = (AGGREGATETUPLE_TYPE, TRAINTUPLE_TYPE, COMPOSITE_TRAINTUPLE_TYPE)
models = []

for input_model in input_models:
proc = Thread(target=fetch_model,
args=(AGGREGATETUPLE_TYPE, authorized_types, input_model, directory))
models.append(proc)
proc.start()

for proc in models:
proc.join()
fetch_models(AGGREGATETUPLE_TYPE, authorized_types, input_models, directory)


def prepare_composite_traintuple_input_models(directory, tuple_):
Expand Down
15 changes: 15 additions & 0 deletions backend/substrapp/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,18 @@ def wrapper(*args, **kwargs):
except Exception as e:
logging.exception(e)
return wrapper


class ExceptionThread(threading.Thread):

def run(self):
try:
if self._target:
self._target(*self._args, **self._kwargs)
except BaseException as e:
self._exception = e
raise e
finally:
# Avoid a refcycle if the thread is running a function with
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs
9 changes: 6 additions & 3 deletions backend/substrapp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,12 @@ def get_remote_file(url, auth, content_dst_path=None, **kwargs):
try:
if kwargs.get('stream', False) and content_dst_path is not None:
chunk_size = 1024 * 1024
with open(content_dst_path, 'wb') as fp:
response = requests.get(url, **kwargs)
fp.writelines(response.iter_content(chunk_size))

with requests.get(url, **kwargs) as response:
response.raise_for_status()

with open(content_dst_path, 'wb') as fp:
fp.writelines(response.iter_content(chunk_size))
else:
response = requests.get(url, **kwargs)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
Expand Down

0 comments on commit f57b97e

Please sign in to comment.