Skip to content

Commit

Permalink
Merged PR 225: fix bugs with cli and websockets
Browse files Browse the repository at this point in the history
-fix cli requests import bug by adding it to cli section of pyproject.toml
-move qrcodeT requirement from remote to cli part of pyproject.toml
-fix url error with blocks endpoint by removing trailing slash
-change the websocket connection to only happen once get_data_api is called, not when the sdk object is created
-fix small bug where if validate_token was false the sdk would error out because there was no expiry time
  • Loading branch information
Spencer Vecile committed Mar 1, 2024
2 parents 988384b + 2b81e4c commit 8fa33b1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 105 deletions.
119 changes: 29 additions & 90 deletions sdk/atriumdb/atrium_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ class AtriumSDK:
:param str api_url: Specifies the URL of the server hosting the API in "api" connection type.
:param str token: An authorization token for the API in "api" connection type.
:param str refresh_token: A token to refresh your authorization token if it expires while you are doing something. Only for the API in "api" connection type.
:param bool validate_token: Do you want the sdk to check if your token is valid when the sdk object is created. If it is not valid it will attempt to use the refresh token to get you a new one. Only for "api" connection type.
:param bool validate_token: Do you want the sdk to check if your token is valid when the sdk object is created and during execution. If it is not valid it will attempt to use the refresh token to get you a new one. If false the sdk will not attempt to refresh your token at any point. Only for "api" connection type.
:param str tsc_file_location: A file path pointing to the directory in which the TSC (time series compression) files are written for this dataset. Used to customize the TSC directory location, rather than using `dataset_location/tsc`.
:param str atriumdb_lib_path: A file path pointing to the shared library (CDLL) that powers the compression and decompression. Not required for most users.
:param bool no_pool: If true disables Mariadb connection pooling, instead using a new connection for each query.
Examples:
-----------
Expand Down Expand Up @@ -130,7 +131,8 @@ class AtriumSDK:

def __init__(self, dataset_location: Union[str, PurePath] = None, metadata_connection_type: str = None,
connection_params: dict = None, num_threads: int = None, api_url: str = None, token: str = None,
refresh_token=None, validate_token=True, tsc_file_location: str = None, atriumdb_lib_path: str = None, no_pool=False):
refresh_token=None, validate_token=True, tsc_file_location: str = None, atriumdb_lib_path: str = None,
no_pool=False):

self.dataset_location = dataset_location

Expand Down Expand Up @@ -214,10 +216,11 @@ def __init__(self, dataset_location: Union[str, PurePath] = None, metadata_conne
elif metadata_connection_type == 'api':
# Check if the necessary modules are installed for API connections
if not REQUESTS_INSTALLED:
raise ImportError("Must install requests, python-dotenv, websockets and PyJWT[crypto] or simply atriumdb[remote].")
raise ImportError("Remote mode not installed. Please install atriumdb with pip install atriumdb[remote]")

self.mode = "api"
self.api_url = api_url
self.validate_token = validate_token
# remove the leading http stuff and replace it with ws, also remove any trailing slashes
self.ws_url = api_url.replace("https://", "wss://").replace("http://", "ws://").rstrip('/')
# make this variable so once connection is made in the thread it is available to the sdk object
Expand All @@ -243,7 +246,7 @@ def __init__(self, dataset_location: Union[str, PurePath] = None, metadata_conne

self.token, self.refresh_token = token, refresh_token

if validate_token:
if self.validate_token:
# send get request to the atriumdb api to get the info you need to validate the API token
auth_config_response = requests.get(f'{self.api_url}/auth/cli/code')

Expand All @@ -260,11 +263,6 @@ def __init__(self, dataset_location: Union[str, PurePath] = None, metadata_conne
except jwt.PyJWTError:
# if the token is invalid attempt to refresh it
self._refresh_token()

# need to check incase the token was just refreshed and the connection already made
if self.websock_conn is None:
# connect to the websocket
self._websocket_connect()
else:
raise ValueError("metadata_connection_type must be one of sqlite, mysql, mariadb or api")

Expand Down Expand Up @@ -306,9 +304,8 @@ def __init__(self, dataset_location: Union[str, PurePath] = None, metadata_conne
atexit.register(self.close)

@classmethod
def create_dataset(cls, dataset_location: Union[str, PurePath], database_type: str = None,
protected_mode: str = None, overwrite: str = None, connection_params: dict = None,
no_pool=False):
def create_dataset(cls, dataset_location: Union[str, PurePath], database_type: str = None, protected_mode: str = None,
overwrite: str = None, connection_params: dict = None, no_pool=False):
"""
.. _create_dataset_label:
Expand All @@ -319,6 +316,7 @@ def create_dataset(cls, dataset_location: Union[str, PurePath], database_type: s
:param str protected_mode: Specifies the protection mode of the metadata database. Allowed values are "True" or "False". If "True", data deletion will not be allowed. If "False", data deletion will be allowed. The default behavior can be changed in the `sdk/atriumdb/helpers/config.toml` file.
:param str overwrite: Specifies the behavior to take when new data being inserted overlaps in time with existing data. Allowed values are "error", "ignore", or "overwrite". Upon triggered overwrite: if "error", an error will be raised. If "ignore", the new data will not be inserted. If "overwrite", the old data will be overwritten with the new data. The default behavior can be changed in the `sdk/atriumdb/helpers/config.toml` file.
:param dict connection_params: A dictionary containing connection parameters for "mysql" or "mariadb" database type. It should contain keys for 'host', 'user', 'password', 'database', and 'port'.
:param bool no_pool: If true disables Mariadb connection pooling, instead using a new connection for each query.
:return: An initialized AtriumSDK object.
:rtype: AtriumSDK
Expand Down Expand Up @@ -428,7 +426,8 @@ def get_data(self, measure_id: int = None, start_time_n: int = None, end_time_n:
raise ValueError("Invalid time units. Expected one of: %s" % time_unit_options)

# make sure time type is either 1 or 2
assert time_type in ALLOWED_TIME_TYPES, "Time type must be in [1, 2]"
if time_type not in ALLOWED_TIME_TYPES:
raise ValueError("Time type must be in [1, 2]")

# convert start and end time to nanoseconds
start_time_n = int(start_time_n * time_unit_options[time_units])
Expand Down Expand Up @@ -556,39 +555,11 @@ def get_data_from_blocks(self, block_list, filename_dict, start_time_n, end_time
def _get_data_api(self, measure_id: int, start_time_n: int, end_time_n: int, device_id: int = None,
patient_id: int = None, mrn: int = None, time_type=1, analog=True, sort=True,
allow_duplicates=True):
"""
.. _get_data_api_label:
Retrieve data from the API for a specific measure within a given time range, and optionally for a specific device,
patient or medical record number (MRN). This function is automatically called by get_data when in "api" mode.
:param measure_id: The ID of the measure to retrieve data for.
:param start_time_n: The start time (in nanoseconds) to retrieve data from.
:param end_time_n: The end time (in nanoseconds) to retrieve data until. The end time is not
inclusive so if you want the end time to be included you have to add one sample period to it.
:param device_id: (Optional) The ID of the device to retrieve data for.
:param patient_id: (Optional) The ID of the patient to retrieve data for.
:param mrn: (Optional) The medical record number (MRN) to retrieve data for.
:param time_type: The time type returned to you. Time_type=1 is time stamps which is what most people will
want. Time_type=2 is gap array and should only be used by advanced users. Note that sorting will not work for
time type 2 and you may receive more values than you asked for because of this.
:param analog: Convert digitized data to its analog values (default: True).
:param bool sort: Whether to sort the returned data. If false you may receive more data than just
[start_time_n:end_time_n).
:param bool allow_duplicates: Whether to allow duplicate times in the sorted returned data if they exist. Does
nothing if sort is false.
:return: A tuple containing headers, request times, and request values.
Example usage:
>>> headers, r_times, r_values = _get_data_api(1, 0, 1000000000, device_id=123)
"""

params = {'start_time': start_time_n, 'end_time': end_time_n, 'measure_id': measure_id, 'device_id': device_id,
'patient_id': patient_id, 'mrn': mrn}
# Request the block information
block_info_list = self._request("GET", 'sdk/blocks/', params=params)
block_info_list = self._request("GET", 'sdk/blocks', params=params)

# Check if there are no blocks in the response
if len(block_info_list) == 0:
Expand All @@ -599,7 +570,7 @@ def _get_data_api(self, measure_id: int, start_time_n: int, end_time_n: int, dev
num_bytes_list = [row['num_bytes'] for row in block_info_list]

# tik = time.perf_counter()
encoded_bytes = self.block_websocket_request(block_info_list)
encoded_bytes = self._block_websocket_request(block_info_list)
# print(f"Time for {len(block_info_list)} blocks over websocket: {round((time.perf_counter() - tik) * 1000, 4)} ms")
# print(f"Mb/s is {(np.sum(num_bytes_list)/(time.perf_counter() - tik))/1_000_000}\n")\

Expand All @@ -613,21 +584,18 @@ def _get_data_api(self, measure_id: int, start_time_n: int, end_time_n: int, dev

return headers, r_times, r_values

def block_websocket_request(self, block_info_list):
"""
Get block bytes using a websocket for multiple requests.
:param block_info_list: A list of dictionaries containing block information, such as block ID.
:return: A list of block bytes.
"""
if not REQUESTS_INSTALLED:
raise ImportError("websockets module is not installed.")
def _block_websocket_request(self, block_info_list):

# check if the api token will expire within 30 seconds and if so refresh it
if time.time() >= self.token_expiry - 30:
if self.validate_token and time.time() >= self.token_expiry - 30:
# get new API token
self._refresh_token()

# If there is no websocket connection create it now
if self.websock_conn is None:
# connect to the websocket
self._websocket_connect()

# make a comma delimited string of all the blocks we want from the API
block_ids = ','.join([str(row['id']) for row in block_info_list])
self.websock_conn.send(block_ids)
Expand Down Expand Up @@ -2996,17 +2964,7 @@ def get_all_label_name_descendents(self, label_name_id: int = None, name: str =
return self._build_descendants_tree(label_name_id, descendants, max_depth)

def _build_descendants_tree(self, root_id, descendants, max_depth, current_depth=0):
"""
Recursive helper method to build the nested dictionary of descendants.
:param int root_id: The root ID of the current subtree.
:param list descendants: List of all descendants.
:param int max_depth: The maximum depth of the tree.
:param int current_depth: The current depth in the tree.

:return: A nested dictionary for the current subtree.
:rtype: dict
"""
if max_depth is not None and current_depth >= max_depth:
return {}

Expand Down Expand Up @@ -3504,32 +3462,12 @@ def get_source_info(self, source_id: int) -> dict | None:
return None

def _request(self, method: str, endpoint: str, **kwargs):
"""
Send an API request using the specified method and endpoint.
This method checks if the `requests` module is installed, and then sends the API request
using the provided method and endpoint.
:param method: The HTTP method to use for the request (e.g., 'GET', 'POST', etc.).
:type method: str
:param endpoint: The API endpoint to send the request to (e.g., '/users').
:type endpoint: str
:param kwargs: Additional keyword arguments to pass to the `requests.request` function.
:raises ImportError: If the `requests` module is not installed.
:raises ValueError: If the API request returns a non-200 status code.
:return: The JSON response from the API request.
:rtype: dict
"""

# Check if the `requests` module is installed.
if not REQUESTS_INSTALLED:
raise ImportError("requests module is not installed.")

# Construct the full URL by combining the base API URL and the endpoint.
url = f"{self.api_url.rstrip('/')}/{endpoint.lstrip('/')}"

# check if the api token will expire within 30 seconds and if so refresh it
if time.time() >= self.token_expiry-30:
if self.validate_token and time.time() >= self.token_expiry - 30:
# get new API token
self._refresh_token()

Expand All @@ -3549,7 +3487,8 @@ def _request(self, method: str, endpoint: str, **kwargs):

def _websocket_connect(self):
def conn():
self.websock_conn = connect(f"{self.ws_url}/sdk/blocks/ws", compression=None, max_size=None, additional_headers={"Authorization": "Bearer {}".format(self.token)})
self.websock_conn = connect(f"{self.ws_url}/sdk/blocks/ws", compression=None, max_size=None,
additional_headers={"Authorization": "Bearer {}".format(self.token)})

# The websockets lib uses a thread to receive messages. Normally you would call close() after receiving the
# messages to shut that thread down but since we are reducing overhead we want to keep that connection open
Expand All @@ -3569,6 +3508,7 @@ def _refresh_token(self):
if self.websock_conn is not None:
# close old websocket connection
self.websock_conn.close()
self.websock_conn = None

# send request to Auth0 to refresh your API token using your refresh token
token_payload = {'grant_type': 'refresh_token', 'client_id': self.auth_config['auth0_client_id'], 'refresh_token': self.refresh_token}
Expand All @@ -3587,15 +3527,14 @@ def _refresh_token(self):
decoded_token = _validate_bearer_token(self.token, self.auth_config)
self.token_expiry = decoded_token['exp']

# make sure the user is using a .env file to store the token
# if the user is using a .env file to store the token
if self.dot_env_loaded:
# change the api token in the .env file and the atriumdb class attribute to the new token
# change the api token in the .env file
set_key("./.env", "ATRIUMDB_API_TOKEN", token_data['access_token'])
# load the new environment variables into the OS
load_dotenv(dotenv_path="./.env", override=True)

# reconnect to the API websocket with the new token
self._websocket_connect()
_LOGGER.debug("Expired token refreshed")

def close(self):
"""
Expand All @@ -3606,7 +3545,7 @@ def close(self):
# make sure we are in api mode and if we are close the connection
if self.mode == "api" and self.websock_conn is not None:
self.websock_conn.close()
print("Websocket connection closed")
_LOGGER.debug("Websocket connection closed")
# if we are in sql mode and there is a connection pool close it
elif (self.metadata_connection_type == "mariadb" or self.metadata_connection_type == "mysql") and self.sql_handler.pool is not None:
self.sql_handler.pool.close()
Expand Down
14 changes: 3 additions & 11 deletions sdk/atriumdb/cli/atriumdb_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,21 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import click
from tabulate import tabulate
from pathlib import Path

import requests
import qrcodeT
import os
import time
import logging
from tabulate import tabulate
from pathlib import Path
from urllib.parse import urlparse

from dotenv import load_dotenv, set_key, get_key

from atriumdb.atrium_sdk import AtriumSDK
from atriumdb.windowing.definition import DatasetDefinition
from atriumdb.adb_functions import parse_metadata_uri
from atriumdb.cli.sdk import get_sdk_from_env_vars, create_sdk_from_env_vars
from atriumdb.sql_handler.sql_constants import SUPPORTED_DB_TYPES
from atriumdb.transfer.cohort.cohort_yaml import parse_atrium_cohort_yaml
from atriumdb.transfer.adb.dataset import transfer_data
from atriumdb.transfer.formats.dataset import export_dataset, import_dataset
from atriumdb.transfer.formats.export_data import export_data_from_sdk
from atriumdb.transfer.formats.import_data import import_data_to_sdk

import logging

_LOGGER = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion sdk/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = 'AtriumDB'
copyright = '2024, The Hospital for Sick Children'
author = 'LaussenLabs'
release = '2.2.1'
release = '2.2.2'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
6 changes: 3 additions & 3 deletions sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mypkg = ["*.so", "*.dll"]

[project]
name = "atriumdb"
version = "2.2.1"
version = "2.2.2"
description = "Timeseries Database"
readme = "README.md"
authors = [{name = "Robert Greer, William Dixon, Spencer Vecile"}, { name = "Robert Greer", email = "[email protected]"}, { name = "William Dixon", email = "[email protected]" }, { name = "Spencer Vecile", email = "[email protected]"}]
Expand Down Expand Up @@ -49,11 +49,12 @@ mariadb = [
remote = [
"requests >= 2.28.2, < 3",
"PyJWT[crypto] >= 2.8.0, < 3",
"qrcodeT >= 1.0.4, < 2",
"python-dotenv >= 0.21, < 1",
"websockets >= 12.0, < 13",
]
cli = [
"requests >= 2.28.2, < 3",
"qrcodeT >= 1.0.4, < 2",
"click >= 8.1.3, < 9",
"pandas >= 1.5, < 2",
"tabulate >= 0.9.0, < 1",
Expand All @@ -72,7 +73,6 @@ all = [
"pandas >= 1.5, < 2",
"tabulate >= 0.9.0, < 1",
"fastparquet == 2023.2.0",
"PyYAML >= 6.0"
]

[dev-dependencies]
Expand Down

0 comments on commit 8fa33b1

Please sign in to comment.