From 58a3c145fcfd846a1a6a6e797b7827a8634ca9b4 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 11 Oct 2024 14:01:00 +0300 Subject: [PATCH] Update cache from records in OneAlyx.search --- one/api.py | 50 +++++++++++++++++++++++++++++------- one/tests/test_alyxclient.py | 22 ++++++++++++++++ one/tests/test_one.py | 21 ++++++++++++++- one/webclient.py | 43 +++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 10 deletions(-) diff --git a/one/api.py b/one/api.py index 412ee206..17c2ac1f 100644 --- a/one/api.py +++ b/one/api.py @@ -3,6 +3,7 @@ import urllib.parse import warnings import logging +from weakref import WeakMethod from datetime import datetime, timedelta from functools import lru_cache, partial from inspect import unwrap @@ -98,14 +99,17 @@ def search_terms(self, query_type=None) -> tuple: def _reset_cache(self): """Replace the cache object with a Bunch that contains the right fields.""" - self._cache = Bunch({'_meta': { - 'expired': False, - 'created_time': None, - 'loaded_time': None, - 'modified_time': None, - 'saved_time': None, - 'raw': {} # map of original table metadata - }}) + self._cache = Bunch({ + 'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']), + 'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'), + '_meta': { + 'expired': False, + 'created_time': None, + 'loaded_time': None, + 'modified_time': None, + 'saved_time': None, + 'raw': {}} # map of original table metadata + }) def load_cache(self, tables_dir=None, **kwargs): """ @@ -187,7 +191,7 @@ def _save_cache(self, save_dir=None, force=False): If True, the cache is saved regardless of modification time. """ TIMEOUT = 5 # Delete lock file this many seconds after creation/modification or waiting - lock_file = Path(self.cache_dir).joinpath('.cache.lock') + lock_file = Path(self.cache_dir).joinpath('.cache.lock') # TODO use iblutil method here save_dir = Path(save_dir or self.cache_dir) meta = self._cache['_meta'] modified = meta.get('modified_time') or datetime.min @@ -2271,6 +2275,18 @@ def search(self, details=False, query_type=None, **kwargs): params.pop('django') # Make GET request ses = self.alyx.rest(self._search_endpoint, 'list', **params) + + # Update cache table with results + if len(ses) == 0: + pass # no need to update cache here + elif isinstance(ses, list): # not a paginated response + self._update_sessions_table(ses) + else: + # populate first page + self._update_sessions_table(ses._cache[:ses.limit]) + # Add callback for updating cache on future fetches + ses.add_callback(WeakMethod(self._update_sessions_table)) + # LazyId only transforms records when indexed eids = util.LazyId(ses) if not details: @@ -2284,6 +2300,22 @@ def _add_date(records): return eids, util.LazyId(ses, func=_add_date) + def _update_sessions_table(self, session_records): + """Update the sessions tables with a list of session records. + + Parameters + ---------- + session_records : list of dict + A list of session records from the /sessions list endpoint. + + Returns + ------- + datetime.datetime: + A timestamp of when the cache was updated. + """ + df = pd.DataFrame(next(zip(*map(util.ses2records, session_records)))) + return self._update_cache_from_records(sessions=df) + def _download_datasets(self, dsets, **kwargs) -> List[Path]: """ Download a single or multitude of datasets if stored on AWS, otherwise calls diff --git a/one/tests/test_alyxclient.py b/one/tests/test_alyxclient.py index 09615e6e..1b352b89 100644 --- a/one/tests/test_alyxclient.py +++ b/one/tests/test_alyxclient.py @@ -3,6 +3,7 @@ from unittest import mock import urllib.parse import random +import weakref import os import one.webclient as wc import one.params @@ -498,12 +499,24 @@ def test_paginated_response(self): self.assertTrue(not any(pg._cache[lim:])) self.assertIs(pg.alyx, alyx) + # Check adding callbacks + self.assertRaises(TypeError, pg.add_callback, None) + wf = mock.Mock(spec_set=weakref.ref) + cb1, cb2 = mock.MagicMock(), wf() + pg.add_callback(cb1) + pg.add_callback(wf) + self.assertEqual(2, len(pg._callbacks)) + # Check fetching cached item with +ve int self.assertEqual({'id': 1}, pg[1]) alyx._generic_request.assert_not_called() + for cb in [cb1, cb2]: + cb.assert_not_called() # Check fetching cached item with +ve slice self.assertEqual([{'id': 1}, {'id': 2}], pg[1:3]) alyx._generic_request.assert_not_called() + for cb in [cb1, cb2]: + cb.assert_not_called() # Check fetching cached item with -ve int self.assertEqual({'id': 100}, pg[-1900]) alyx._generic_request.assert_not_called() @@ -518,6 +531,10 @@ def test_paginated_response(self): self.assertEqual(res['results'], pg._cache[offset:offset + lim]) alyx._generic_request.assert_called_once_with(requests.get, mock.ANY, clobber=True) self._check_get_query(alyx._generic_request.call_args, lim, offset) + for cb in [cb1, cb2]: + cb.assert_called_once_with(res['results']) + # Check that dead weakreaf will be removed from the list on next call + wf.return_value = None # Check fetching uncached item with -ve int offset = lim * 3 res['results'] = [{'id': i} for i in range(offset, offset + lim)] @@ -527,6 +544,7 @@ def test_paginated_response(self): self.assertEqual(res['results'], pg._cache[offset:offset + lim]) alyx._generic_request.assert_called_with(requests.get, mock.ANY, clobber=True) self._check_get_query(alyx._generic_request.call_args, lim, offset) + self.assertEqual(1, len(pg._callbacks), 'failed to remove weakref callback') # Check fetching uncached item with +ve slice offset = lim * 5 res['results'] = [{'id': i} for i in range(offset, offset + lim)] @@ -548,6 +566,10 @@ def test_paginated_response(self): self.assertEqual(expected_calls := 4, alyx._generic_request.call_count) self.assertEqual((expected_calls + 1) * lim, sum(list(map(bool, pg._cache)))) + # Check callbacks cleared when cache fully populated + self.assertTrue(all(map(bool, pg))) + self.assertEqual(0, len(pg._callbacks)) + def _check_get_query(self, call_args, limit, offset): """Check URL get query contains the expected limit and offset params.""" (_, url), _ = call_args diff --git a/one/tests/test_one.py b/one/tests/test_one.py index 13e4fec8..4824597d 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -1475,8 +1475,17 @@ def test_list_datasets(self): def test_search(self): """Test OneAlyx.search method in remote mode.""" + # Modify sessions dataframe so we can check that the records get updated + records = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043'] + self.one._cache.sessions.loc[records.index, 'lab'] = 'foolab' # change a field + self.one._cache.sessions.drop(self.eid, inplace=True) # remove a row + + # Check remote seach of subject eids = self.one.search(subject='SWC_043', query_type='remote') self.assertIn(self.eid, list(eids)) + updated = self.one._cache.sessions[self.one._cache.sessions.subject == 'SWC_043'] + self.assertCountEqual(eids, updated.index) + self.assertFalse('foolab' in updated['lab']) eids, det = self.one.search(subject='SWC_043', query_type='remote', details=True) correct = len(det) == len(eids) and 'url' in det[0] and det[0]['url'].endswith(eids[0]) @@ -1501,10 +1510,20 @@ def test_search(self): dates = set(map(lambda x: self.one.get_details(x)['date'], eids)) self.assertTrue(dates <= set(date_range)) - # Test limit arg and LazyId + # Test limit arg, LazyId, and update with paginated response callback + self.one._reset_cache() # Remove sessions table + assert self.one._cache.sessions.empty eids = self.one.search(date='2020-03-23', limit=2, query_type='remote') + self.assertEqual(2, len(self.one._cache.sessions), + 'failed to update cache with first page of search results') self.assertIsInstance(eids, LazyId) + assert len(eids) > 5, 'in order to check paginated response callback we need several pages' + e = eids[-3] # access an uncached value + self.assertEqual( + 4, len(self.one._cache.sessions), 'failed to update cache after page access') + self.assertTrue(e in self.one._cache.sessions.index) self.assertTrue(all(len(x) == 36 for x in eids)) + self.assertEqual(len(eids), len(self.one._cache.sessions)) # Test laboratory kwarg eids = self.one.search(laboratory='hoferlab', query_type='remote') diff --git a/one/webclient.py b/one/webclient.py index d6b749d7..64b83f68 100644 --- a/one/webclient.py +++ b/one/webclient.py @@ -40,6 +40,7 @@ from typing import Optional from datetime import datetime, timedelta from pathlib import Path +from weakref import ReferenceType import warnings import hashlib import zipfile @@ -206,6 +207,23 @@ def __init__(self, alyx, rep, cache_args=None): # fill the cache with results of the query for i in range(self.limit): self._cache[i] = rep['results'][i] + self._callbacks = set() + + def add_callback(self, cb): + """Add a callback function to use each time a new page is fetched. + + The callback function will be called with the page results each time :meth:`populate` + is called. + + Parameters + ---------- + cb : callable + A callable that takes the results of each paginated resonse. + """ + if not callable(cb): + raise TypeError(f'Expected type "callable", got "{type(cb)}" instead') + else: + self._callbacks.add(cb) def __len__(self): return self.count @@ -222,6 +240,16 @@ def __getitem__(self, item): return self._cache[item] def populate(self, idx): + """Populate response cache with new page of results. + + Fetches the specific page of results containing the index passed and populates + stores the results in the :prop:`_cache` property. + + Parameters + ---------- + idx : int + The index of a given record to fetch. + """ offset = self.limit * math.floor(idx / self.limit) query = update_url_params(self.query, {'limit': self.limit, 'offset': offset}) res = self.alyx._generic_request(requests.get, query, **self._cache_args) @@ -231,6 +259,21 @@ def populate(self, idx): f'results may be inconsistent', RuntimeWarning) for i, r in enumerate(res['results'][:self.count - offset]): self._cache[i + offset] = res['results'][i] + # Notify callbacks + pending_removal = [] + for callback in self._callbacks: + # Handle weak reference callbacks first + if isinstance(callback, ReferenceType): + wf = callback + if (callback := wf()) is None: + pending_removal.append(wf) + continue + callback(res['results']) + for wf in pending_removal: + self._callbacks.discard(wf) + # When cache is complete, clear our callbacks + if all(reversed(self._cache)): + self._callbacks.clear() def __iter__(self): for i in range(self.count):