
Merge branch 'hotfix/2.21.2'
k1o0 committed Feb 24, 2023
2 parents 5fbbf62 + cb88696 commit 98b96b0
Showing 9 changed files with 201 additions and 152 deletions.
2 changes: 1 addition & 1 deletion ibllib/__init__.py
@@ -2,7 +2,7 @@
import logging
import warnings

__version__ = '2.21.1'
__version__ = '2.21.2'
warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib')

# if this becomes a full-blown library we should let the logging configuration to the discretion of the dev
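
As a quick sanity check that the hotfix is the version actually installed, the bumped version string can be read straight back from the package (illustrative snippet, not part of the commit):

import ibllib
assert ibllib.__version__ == '2.21.2', ibllib.__version__
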
72 changes: 13 additions & 59 deletions ibllib/oneibl/data_handlers.py
@@ -12,28 +12,13 @@
from one.util import filter_datasets
from one.alf.files import add_uuid_string, session_path_parts
from iblutil.io.parquet import np2str
from ibllib.oneibl.registration import register_dataset
from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository
from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH


_logger = logging.getLogger(__name__)


def get_local_data_repository(one):
if one is None:
return

if not Path.home().joinpath(".globusonline/lta/client-id.txt").exists():
return

with open(Path.home().joinpath(".globusonline/lta/client-id.txt"), 'r') as fid:
globus_id = fid.read()

data_repo = one.alyx.rest('data-repository', 'list', globus_endpoint_id=globus_id)
if len(data_repo):
return [da['name'] for da in data_repo][0]


class DataHandler(abc.ABC):
def __init__(self, session_path, signature, one=None):
"""
@@ -47,10 +32,7 @@ def __init__(self, session_path, signature, one=None):
self.one = one

def setUp(self):
"""
Function to optionally overload to download required data to run task
:return:
"""
"""Function to optionally overload to download required data to run task."""
pass

def getData(self, one=None):
@@ -126,7 +108,7 @@ def uploadData(self, outputs, version, **kwargs):
:return: output info of registered datasets
"""
versions = super().uploadData(outputs, version)
data_repo = get_local_data_repository(self.one)
data_repo = get_local_data_repository(self.one.alyx)
return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)


@@ -136,7 +118,7 @@ def __init__(self, session_path, signatures, one=None):
Data handler for running tasks on lab local servers. Will download missing data from SDSC using Globus
:param session_path: path to session
:param signature: input and output file signatures
:param signatures: input and output file signatures
:param one: ONE instance
"""
from one.remote.globus import Globus, get_lab_from_endpoint_id # noqa
@@ -147,14 +129,7 @@ def __init__(self, session_path, signatures, one=None):
self.globus.endpoints['local']['root_path'] = '/mnt/s0/Data/Subjects'

# Find the lab
labs = get_lab_from_endpoint_id(alyx=self.one.alyx)

if len(labs) == 2:
# for flofer lab
subject = self.one.path2ref(self.session_path)['subject']
self.lab = self.one.alyx.rest('subjects', 'list', nickname=subject)[0]['lab']
else:
self.lab = labs[0]
self.lab = get_lab(self.session_path, self.one.alyx)

# For cortex lab we need to get the endpoint from the ibl alyx
if self.lab == 'cortexlab':
@@ -165,10 +140,7 @@ def __init__(self, session_path, signatures, one=None):
self.local_paths = []

def setUp(self):
"""
Function to download necessary data to run tasks using globus-sdk
:return:
"""
"""Function to download necessary data to run tasks using globus-sdk."""
if self.lab == 'cortexlab':
one = ONE(base_url='https://alyx.internationalbrainlab.org')
df = super().getData(one=one)
@@ -221,14 +193,11 @@ def uploadData(self, outputs, version, **kwargs):
:return: output info of registered datasets
"""
versions = super().uploadData(outputs, version)
data_repo = get_local_data_repository(self.one)
data_repo = get_local_data_repository(self.one.alyx)
return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)

def cleanUp(self):
"""
Clean up, remove the files that were downloaded from globus once task has completed
:return:
"""
"""Clean up, remove the files that were downloaded from globus once task has completed."""
for file in self.local_paths:
os.unlink(file)

@@ -280,10 +249,7 @@ def __init__(self, task, session_path, signature, one=None):
self.local_paths = []

def setUp(self):
"""
Function to download necessary data to run tasks using AWS boto3
:return:
"""
"""Function to download necessary data to run tasks using AWS boto3."""
df = super().getData()
self.local_paths = self.one._download_aws(map(lambda x: x[1], df.iterrows()))

@@ -362,10 +328,7 @@ def uploadData(self, outputs, version, **kwargs):
# versions=versions, **kwargs)

def cleanUp(self):
"""
Clean up, remove the files that were downloaded from globus once task has completed
:return:
"""
"""Clean up, remove the files that were downloaded from globus once task has completed."""
if self.task.status == 0:
for file in self.local_paths:
os.unlink(file)
@@ -383,10 +346,7 @@ def __init__(self, session_path, signature, one=None):
super().__init__(session_path, signature, one=one)

def setUp(self):
"""
Function to download necessary data to run tasks using globus
:return:
"""
"""Function to download necessary data to run tasks using globus."""
# TODO
pass

@@ -416,10 +376,7 @@ def __init__(self, task, session_path, signatures, one=None):
self.task = task

def setUp(self):
"""
Function to create symlinks to necessary data to run tasks
:return:
"""
"""Function to create symlinks to necessary data to run tasks."""
df = super().getData()

SDSC_TMP = Path(SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
@@ -451,9 +408,6 @@ def uploadData(self, outputs, version, **kwargs):
return sdsc_patcher.patch_datasets(outputs, dry=False, versions=versions, **kwargs)

def cleanUp(self):
"""
Function to clean up symlinks created to run task
:return:
"""
"""Function to clean up symlinks created to run task."""
assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
shutil.rmtree(self.task.session_path)
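
Taken together, the data_handlers.py changes remove the module's private copy of get_local_data_repository and the inline lab lookup; handlers now import both helpers from ibllib.oneibl.registration and pass the AlyxClient (one.alyx) rather than the ONE instance itself. A minimal sketch of the new call pattern, assuming a configured ONE connection and a hypothetical session path:

from one.api import ONE
from ibllib.oneibl.registration import get_lab, get_local_data_repository

one = ONE()  # assumes ONE is already set up for your Alyx database
session_path = '/mnt/s0/Data/Subjects/SW_001/2023-02-20/001'  # hypothetical session

# Both helpers now take the AlyxClient, not the ONE instance
data_repo = get_local_data_repository(one.alyx)  # local server's data repository name, or None
lab = get_lab(session_path, one.alyx)            # lab resolved from the subject name on Alyx
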
41 changes: 32 additions & 9 deletions ibllib/oneibl/registration.py
@@ -7,7 +7,9 @@
from pkg_resources import parse_version
from one.alf.files import get_session_path, folder_parts, get_alf_path
from one.registration import RegistrationClient, get_dataset_type
from one.remote.globus import get_local_endpoint_id
from one.remote.globus import get_local_endpoint_id, get_lab_from_endpoint_id
from one.webclient import AlyxClient
from one.converters import ConversionMixin
import one.alf.exceptions as alferr
from one.util import datasets2records, ensure_list

@@ -462,20 +464,41 @@ def get_local_data_repository(ac):
return next((da['name'] for da in data_repo), None)


def get_lab(ac):
def get_lab(session_path, alyx=None):
"""
Get list of associated labs from Globus client ID.
Get lab from a session path using the subject name.
On local lab servers, the lab name is not in the ALF path and the globus endpoint ID may be
associated with multiple labs, so lab name is fetched from the subjects endpoint.
Parameters
----------
ac : one.webclient.AlyxClient
session_path : str, pathlib.Path
The session path from which to determine the lab name.
alyx : one.webclient.AlyxClient
An AlyxClient instance for querying data repositories.
Returns
-------
list
The lab names associated with the local Globus endpoint ID.
str
The lab name associated with the session path subject.
See Also
--------
one.remote.globus.get_lab_from_endpoint_id
"""
globus_id = get_local_endpoint_id()
lab = ac.rest('labs', 'list', django=f'repositories__globus_endpoint_id,{globus_id}')
return [la['name'] for la in lab]
alyx = alyx or AlyxClient()
if not (ref := ConversionMixin.path2ref(session_path)):
raise ValueError(f'Failed to parse session path: {session_path}')

labs = [x['lab'] for x in alyx.rest('subjects', 'list', nickname=ref['subject'])]
if len(labs) == 0:
raise alferr.AlyxSubjectNotFound(ref['subject'])
elif len(labs) > 1: # More than one subject with this nickname
# use local endpoint ID to find the correct lab
endpoint_labs = get_lab_from_endpoint_id(alyx=alyx)
lab = next(x for x in labs if x in endpoint_labs)
else:
lab, = labs

return lab
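
The subject is recovered purely from the session path via ONE's ConversionMixin, which is also how the new get_lab fails fast on malformed paths. A small illustration of that parsing step (the exact return type may vary between ONE versions; get_lab only relies on the 'subject' key):

from one.converters import ConversionMixin

ref = ConversionMixin.path2ref('/mnt/s0/Data/Subjects/SW_001/2023-02-20/001')  # hypothetical path
print(ref['subject'])  # -> 'SW_001'; get_lab raises ValueError when path2ref returns None
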
70 changes: 45 additions & 25 deletions ibllib/pipes/local_server.py
@@ -10,6 +10,8 @@
import importlib

from one.api import ONE
from one.webclient import AlyxClient
from one.remote.globus import get_lab_from_endpoint_id

from ibllib.io.extractors.base import get_pipeline, get_task_protocol, get_session_extractor_type
from ibllib.pipes import tasks, training_preprocessing, ephys_preprocessing
@@ -72,35 +74,52 @@ def report_health(one):
status.update(_get_volume_usage('/mnt/s0/Data', 'raid'))
status.update(_get_volume_usage('/', 'system'))

lab_names = get_lab(one.alyx)
lab_names = get_lab_from_endpoint_id(alyx=one.alyx)
for ln in lab_names:
one.alyx.json_field_update(endpoint='labs', uuid=ln, field_name='json', data=status)


def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None):
"""
Server function that will look for creation flags and for each:
1) create the sessions on Alyx
2) register the corresponding raw data files on Alyx
3) create the tasks to be run on Alyx
:param root_path: main path containing sessions or session path
:param one
:param dry
:param rerun
:param max_md5_size
:return:
Create new sessions and pipelines.
Server function that will look for 'raw_session.flag' files and for each:
1) create the session on Alyx
2) create the tasks to be run on Alyx
For legacy sessions the raw data are registered separately, instead of within a pipeline task.
Parameters
----------
root_path : str, pathlib.Path
Main path containing sessions or a session path.
one : one.api.OneAlyx
An ONE instance for registering the session(s).
dry : bool
If true, simply log the session_path(s) found, without registering anything.
rerun : bool
If true and session pipeline tasks already exist, set them all to waiting.
max_md5_size : int
(legacy sessions) The maximum file size to calculate the MD5 hash sum for.
Returns
-------
list of ibllib.pipes.tasks.Pipeline
The pipelines created.
list of dicts
A list of any datasets registered (only for legacy sessions)
"""
if not one:
one = ONE(cache_rest=None)
rc = IBLRegistrationClient(one=one)
flag_files = list(Path(root_path).glob('**/raw_session.flag'))
pipes = []
all_datasets = []
for flag_file in flag_files:
session_path = flag_file.parent
_logger.info(f'creating session for {session_path}')
if dry:
continue

try:
# if the subject doesn't exist in the database, skip
rc.register_session(session_path, file_list=False)
@@ -112,9 +131,9 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None):
else:
# Create legacy experiment description file
acquisition_description_legacy_session(session_path, save=True)
labs = ','.join(get_lab(one.alyx))
files, dsets = register_session_raw_data(session_path, one=one, max_md5_size=max_md5_size, labs=labs)
if dsets is not None:
lab = get_lab(session_path, one.alyx) # Can be set to None to do this Alyx-side if using ONE v1.20.1
_, dsets = register_session_raw_data(session_path, one=one, max_md5_size=max_md5_size, labs=lab)
if dsets:
all_datasets.extend(dsets)
pipe = _get_pipeline_class(session_path, one)
if pipe is None:
@@ -126,15 +145,17 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None):
rerun__status__in = ['Waiting']
pipe.create_alyx_tasks(rerun__status__in=rerun__status__in)
flag_file.unlink()
if pipe is not None:
pipes.append(pipe)
except Exception:
_logger.error(traceback.format_exc())
_logger.warning(f'Creating session / registering raw datasets {session_path} errored')
continue

return all_datasets
return pipes, all_datasets
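
Note that job_creator now returns two values instead of one, so callers must unpack both the created pipelines and any legacy-registered datasets. A hedged example of how a local-server script might invoke it (the root path is illustrative):

from one.api import ONE
from ibllib.pipes.local_server import job_creator

one = ONE(cache_rest=None)
pipes, datasets = job_creator('/mnt/s0/Data/Subjects', one=one, dry=False)
print(f'{len(pipes)} pipelines created, {len(datasets)} datasets registered')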


def task_queue(mode='all', lab=None, one=None):
def task_queue(mode='all', lab=None, alyx=None):
"""
Query waiting jobs from the specified Lab
:param mode: Whether to return all waiting tasks, or only small or large (specified in LARGE_TASKS) jobs
@@ -143,18 +164,17 @@ def task_queue(mode='all', lab=None, one=None):
-------
"""
if one is None:
one = ONE(cache_rest=None)
alyx = alyx or AlyxClient(cache_rest=None)
if lab is None:
_logger.debug("Trying to infer lab from globus installation")
lab = get_lab(one.alyx)
_logger.debug('Trying to infer lab from globus installation')
lab = get_lab_from_endpoint_id(alyx=alyx)
if lab is None:
_logger.error("No lab provided or found")
_logger.error('No lab provided or found')
return # if the lab is none, this will return empty tasks each time
data_repo = get_local_data_repository(one)
data_repo = get_local_data_repository(alyx)
# Filter for tasks
tasks_all = one.alyx.rest('tasks', 'list', status='Waiting',
django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True)
tasks_all = alyx.rest('tasks', 'list', status='Waiting',
django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True)
if mode == 'all':
waiting_tasks = tasks_all
else:
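
With this change task_queue is driven by an AlyxClient rather than a full ONE instance, so existing callers passing one= need updating. A minimal sketch of the new call, assuming Alyx credentials are already configured (the lab name and mode value are illustrative; 'small'/'large' modes are filtered against the module's LARGE_TASKS list):

from one.webclient import AlyxClient
from ibllib.pipes.local_server import task_queue

alyx = AlyxClient(cache_rest=None)
waiting = task_queue(mode='small', lab='mylab', alyx=alyx)
for task in waiting or []:  # task_queue returns None when no lab can be determined
    print(task['name'])     # field names as returned by the Alyx tasks REST endpoint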
