From 7cfba26fe1415a25946fdd3e78b7d7b6c01d7973 Mon Sep 17 00:00:00 2001 From: davelopez <46503462+davelopez@users.noreply.github.com> Date: Wed, 29 May 2024 16:27:00 +0200 Subject: [PATCH] Fix Invenio credentials handling Only ask for token when is really required --- lib/galaxy/files/sources/_rdm.py | 7 +------ lib/galaxy/files/sources/invenio.py | 31 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/lib/galaxy/files/sources/_rdm.py b/lib/galaxy/files/sources/_rdm.py index 14f7e9e1daa0..25e33ed8b757 100644 --- a/lib/galaxy/files/sources/_rdm.py +++ b/lib/galaxy/files/sources/_rdm.py @@ -7,7 +7,6 @@ from typing_extensions import Unpack -from galaxy.exceptions import AuthenticationRequired from galaxy.files import ProvidesUserFileSourcesUserContext from galaxy.files.sources import ( BaseFilesSource, @@ -193,15 +192,11 @@ def _serialization_props(self, user_context: OptionalUserContext = None): effective_props[key] = self._evaluate_prop(val, user_context=user_context) return effective_props - def get_authorization_token(self, user_context: OptionalUserContext) -> str: + def get_authorization_token(self, user_context: OptionalUserContext) -> Optional[str]: token = None if user_context: effective_props = self._serialization_props(user_context) token = effective_props.get("token") - if not token: - raise AuthenticationRequired( - f"Please provide a personal access token in your user's preferences for '{self.label}'" - ) return token def get_public_name(self, user_context: OptionalUserContext) -> Optional[str]: diff --git a/lib/galaxy/files/sources/invenio.py b/lib/galaxy/files/sources/invenio.py index 921438a446bd..99ec58f1e161 100644 --- a/lib/galaxy/files/sources/invenio.py +++ b/lib/galaxy/files/sources/invenio.py @@ -217,12 +217,7 @@ def create_draft_record( }, } - headers = self._get_request_headers(user_context) - if "Authorization" not in headers: - raise Exception( - "Cannot create record without authentication token. Please set your personal access token in your Galaxy preferences." - ) - + headers = self._get_request_headers(user_context, auth_required=True) response = requests.post(self.records_url, json=create_record_request, headers=headers) self._ensure_response_has_expected_status_code(response, 201) record = response.json() @@ -238,7 +233,7 @@ def upload_file_to_draft_record( ): record = self._get_draft_record(record_id, user_context=user_context) upload_file_url = record["links"]["files"] - headers = self._get_request_headers(user_context) + headers = self._get_request_headers(user_context, auth_required=True) # Add file metadata entry response = requests.post(upload_file_url, json=[{"key": filename}], headers=headers) @@ -394,28 +389,38 @@ def _get_creator_from_public_name(self, public_name: Optional[str] = None) -> Cr } def _get_response( - self, user_context: OptionalUserContext, request_url: str, params: Optional[Dict[str, Any]] = None + self, + user_context: OptionalUserContext, + request_url: str, + params: Optional[Dict[str, Any]] = None, + auth_required: bool = False, ) -> dict: - headers = self._get_request_headers(user_context) + headers = self._get_request_headers(user_context, auth_required) response = requests.get(request_url, params=params, headers=headers) self._ensure_response_has_expected_status_code(response, 200) return response.json() - def _get_request_headers(self, user_context: OptionalUserContext): + def _get_request_headers(self, user_context: OptionalUserContext, auth_required: bool = False): token = self.plugin.get_authorization_token(user_context) headers = {"Authorization": f"Bearer {token}"} if token else {} + if auth_required and token is None: + self._raise_auth_required() return headers def _ensure_response_has_expected_status_code(self, response, expected_status_code: int): - if response.status_code == 403: - record_url = response.url.replace("/api", "").replace("/files", "") - raise AuthenticationRequired(f"Please make sure you have the necessary permissions to access: {record_url}") if response.status_code != expected_status_code: + if response.status_code == 403: + self._raise_auth_required() error_message = self._get_response_error_message(response) raise Exception( f"Request to {response.url} failed with status code {response.status_code}: {error_message}" ) + def _raise_auth_required(self): + raise AuthenticationRequired( + f"Please provide a personal access token in your user's preferences for '{self.plugin.label}'" + ) + def _get_response_error_message(self, response): response_json = response.json() error_message = response_json.get("message") if response.status_code == 400 else response.text