Skip to content

Commit

Permalink
Merge pull request #7 from spapa013/main
Browse files Browse the repository at this point in the history
v0.1.6
  • Loading branch information
spapa013 authored Oct 3, 2022
2 parents 48ac438 + f260525 commit b98e964
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 34 deletions.
5 changes: 3 additions & 2 deletions datajoint_plus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'format_table_name', 'split_full_table_name', 'make_store_dict', 'register_externals',
'generate_hash', 'validate_and_generate_hash', 'parse_definition',
'reform_definition', 'errors', 'free_table',
'basicConfig', 'getLogger', 'LogFileManager']
'basicConfig', 'getLogger', 'LogFileManager', 'BaseMaster', 'BasePart', 'UserTable']



Expand All @@ -34,14 +34,15 @@

# from DataJointPlus
from . import errors
from .base import BaseMaster, BasePart
from .compatibility import add_datajoint_plus, reassign_master_attribute
from .config import config
from .hash import generate_hash, validate_and_generate_hash
from .heading import parse_definition, reform_definition
from .logging import LogFileManager, basicConfig, getLogger
from .schema import DataJointPlusModule, Schema
from .table import FreeTable as free_table
from .user_tables import Computed, Lookup, Part, Manual
from .user_tables import Computed, Lookup, Part, Manual, UserTable
from .utils import (add_objects, check_if_latest_version,
enable_datajoint_flags, format_table_name, make_store_dict,
register_externals, split_full_table_name)
Expand Down
6 changes: 4 additions & 2 deletions datajoint_plus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ def parts(cls, as_objects=False, as_cls=False, reload_dependencies=False):
:returns: list
"""
cls.load_dependencies(force=reload_dependencies)
if not cls.connection.dependencies._loaded or reload_dependencies:
cls.load_dependencies()

cls_parts = [getattr(cls, d) for d in dir(cls) if inspect.isclass(getattr(cls, d)) and issubclass(getattr(cls, d), dj.Part)]
for cls_part in [p.full_table_name for p in cls_parts]:
Expand Down Expand Up @@ -726,7 +727,8 @@ def restrict_parts(cls, part_restr={}, include_parts=None, exclude_parts=None, f
:param filter_out_len_zero: (bool) If True, included parts must have greater than zero rows after restriction is applied.
:param reload_dependencies: (bool) reloads DataJoint graph dependencies.
"""
assert cls.has_parts(reload_dependencies=reload_dependencies), 'No part tables found. If you are expecting part tables, try with reload_dependencies=True.'
if not cls.has_parts(reload_dependencies=reload_dependencies):
logger.warning('No part tables found.')

if include_parts is None:
parts = cls.parts(as_cls=True)
Expand Down
45 changes: 21 additions & 24 deletions datajoint_plus/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ class Schema(dj.Schema):
Additional params:
:param load_dependencies (bool): Loads the DataJoint graph.
"""
def __init__(self, schema_name, context=None, load_dependencies=True, *, connection=None, create_schema=True, create_tables=True):
def __init__(self, schema_name, context=None, load_dependencies=False, update_table_log=False, *, connection=None, create_schema=True, create_tables=True):
super().__init__(schema_name=schema_name, context=context, connection=connection, create_schema=create_schema, create_tables=create_tables)

# attempt to update ~tables
try:
self._tables = None
for table_name in self.list_tables():
full_table_name = reform_full_table_name(self.database, table_name)
self.tables(generate_table_id(full_table_name), full_table_name, action='add')
for key in self._tables:
_, name = split_full_table_name(key['full_table_name'])
if name not in self.list_tables():
self.tables(full_table_name=key['full_table_name'], action='delete')
except:
pass
if update_table_log:
try:
self._tables = None
for table_name in self.list_tables():
full_table_name = reform_full_table_name(self.database, table_name)
self.tables(generate_table_id(full_table_name), full_table_name, action='add')
for key in self._tables:
_, name = split_full_table_name(key['full_table_name'])
if name not in self.list_tables():
self.tables(full_table_name=key['full_table_name'], action='delete')
except:
logger.warning('Could not update schema.tables')

if load_dependencies:
self.load_dependencies(verbose=False)
Expand Down Expand Up @@ -88,7 +88,7 @@ class VirtualModule(types.ModuleType):
A virtual module which will contain context for schema.
"""
def __init__(self, module_name, schema_name, *, create_schema=False,
create_tables=False, connection=None, add_objects=None):
create_tables=False, connection=None, add_objects=None, load_dependencies=False):
"""
Creates a python module with the given name from the name of a schema on the server and
automatically adds classes to it corresponding to the tables in the schema.
Expand All @@ -101,8 +101,8 @@ def __init__(self, module_name, schema_name, *, create_schema=False,
:return: the python module containing classes from the schema object and the table classes
"""
super().__init__(name=module_name)
_schema = Schema(schema_name, create_schema=create_schema, create_tables=create_tables,
connection=connection)
_schema = Schema(schema_name, create_schema=create_schema, load_dependencies=load_dependencies,
create_tables=create_tables, connection=connection)
if add_objects:
self.__dict__.update(add_objects)
self.__dict__['schema'] = _schema
Expand All @@ -113,7 +113,7 @@ class DataJointPlusModule(VirtualModule):
"""
DataJointPlus extension of DataJoint virtual module with the added ability to instantiate from an existing module.
"""
def __init__(self, module_name=None, schema_name=None, module=None, schema_obj_name=None, add_externals=None, add_objects=None, create_schema=False, create_tables=False, connection=None, spawn_missing_classes=True, load_dependencies=True, enable_dj_flags=True, warn=True):
def __init__(self, module_name=None, schema_name=None, module=None, schema_obj_name=None, add_externals=None, add_objects=None, create_schema=False, create_tables=False, connection=None, spawn_missing_classes=True, load_dependencies=False, enable_dj_flags=True, warn=True):
"""
Add DataJointPlus methods to all DataJoint user tables in a DataJoint virtual module or to an existing module.
Expand All @@ -140,9 +140,6 @@ def __init__(self, module_name=None, schema_name=None, module=None, schema_obj_n
assert not module, 'Provide either schema_name or module but not both.'
super().__init__(module_name=module_name if module_name else schema_name, schema_name=schema_name, add_objects=add_objects, create_schema=create_schema, create_tables=create_tables, connection=connection)

if load_dependencies:
self.load_dependencies(verbose=False)

elif module:
super(dj.VirtualModule, self).__init__(name=module.__name__)
if module_name:
Expand All @@ -164,9 +161,6 @@ def __init__(self, module_name=None, schema_name=None, module=None, schema_obj_n
if spawn_missing_classes:
schema_obj.spawn_missing_classes(context=self.__dict__)

if load_dependencies:
self.load_dependencies(verbose=False)

if add_objects:
self.__dict__.update(add_objects)

Expand All @@ -178,7 +172,10 @@ def __init__(self, module_name=None, schema_name=None, module=None, schema_obj_n

if enable_dj_flags:
enable_datajoint_flags()


if load_dependencies:
self.load_dependencies(verbose=False)

add_datajoint_plus(self)

def load_dependencies(self, verbose=True):
Expand Down
15 changes: 10 additions & 5 deletions datajoint_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pandas as pd
import requests
from datajoint.errors import _switch_adapted_types, _switch_filepath_types
from datajoint.errors import _support_adapted_types, _switch_adapted_types, _support_filepath_types, _switch_filepath_types
from datajoint.table import QueryExpression
from datajoint.user_tables import UserTable
from IPython.display import display
Expand Down Expand Up @@ -131,15 +131,20 @@ def load_dependencies(connection, force=False, verbose=True):
connection.dependencies.load(force=force)


def enable_datajoint_flags(enable_python_native_blobs=True):
def enable_datajoint_flags(enable_python_native_blobs=True, support_adapted_types=True, support_filepath_types=True):
"""
Enable experimental datajoint features
These flags are required by 0.12.0+ (for now).
"""
config['enable_python_native_blobs'] = enable_python_native_blobs
_switch_filepath_types(True)
_switch_adapted_types(True)
if config['enable_python_native_blobs'] != enable_python_native_blobs:
config['enable_python_native_blobs'] = enable_python_native_blobs

if _support_adapted_types() != support_adapted_types:
_switch_adapted_types(support_adapted_types)

if _support_filepath_types() != support_filepath_types:
_switch_filepath_types(support_filepath_types)


def register_externals(external_stores):
Expand Down
2 changes: 1 addition & 1 deletion datajoint_plus/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.5"
__version__ = "0.1.6"

0 comments on commit b98e964

Please sign in to comment.