Skip to content

Commit

Permalink
Merge pull request #7 from CybercentreCanada/hotfix/retries
Browse files: browse the repository at this point in the history
hotfix/retries
  • Loading branch information
cccs-sgaron authored Feb 22, 2021
2 parents 5e6b2f7 + 7d61380 commit 29e0c61
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions assemblyline_service_client/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,17 @@ def cleanup_working_directory(self, folder_path):
except Exception:
pass

def request_with_retries(self, method: str, url: str, **kwargs):
def request_with_retries(self, method: str, url: str, max_retry=None, **kwargs):
if 'headers' in kwargs:
self.session.headers.update(kwargs['headers'])
kwargs.pop('headers')
header_dump = '; '.join(f"{k}={v}" for k, v in self.session.headers.items())
self.log.debug('query headers: ' + header_dump)

back_off_time = 1
retry = 0

while True:
while max_retry is None or retry < max_retry:
back_off_time = min(2 ** (retry - 5), 8)
try:
func = getattr(self.session, method)
resp = func(url, **kwargs)
Expand All @@ -195,7 +196,11 @@ def request_with_retries(self, method: str, url: str, **kwargs):

return resp.json()['api_response']
except requests.ConnectionError:
self.log.warning(f"Cannot reach service server. Retrying after {back_off_time}s.")
msg = f"Cannot reach service server. Retrying after {back_off_time}s."
if retry < 2:
self.log.info(msg)
else:
self.log.warning(msg)
time.sleep(back_off_time)
except requests.Timeout: # Handles ConnectTimeout and ReadTimeout
time.sleep(back_off_time)
Expand All @@ -206,7 +211,9 @@ def request_with_retries(self, method: str, url: str, **kwargs):
self.log.error(str(e))
raise

back_off_time = min(back_off_time*2, 8)
retry += 1

return None

def try_run(self):
self.initialize_service()
Expand Down Expand Up @@ -340,13 +347,10 @@ def get_task(self) -> ServiceTask:
def download_file(self, sha256, sid) -> Optional[str]:
self.status = STATUSES.DOWNLOADING_FILE
received_file_sha256 = ''
retry = 0
file_path = None
self.log.info(f"[{sid}] Downloading file: {sha256}")
while received_file_sha256 != sha256 and retry < 3:
r = self.session.get(self._path('file', sha256), headers=self.headers)
retry += 1
# self.log.info(str(r.ok))
r = self.request_with_retries('get', self._path('file', sha256), max_retry=3, headers=self.headers)
if r is not None:
if r.status_code == 404:
self.log.error(f"[{sid}] Requested file not found in the system: {sha256}")
return None
Expand Down

0 comments on commit 29e0c61

Please sign in to comment.