Merge pull request #9 from spapa013/main
v0.1.8
spapa013 authored Jan 9, 2024
2 parents a7d0777 + 41495c2 commit 0bcf320
Showing 2 changed files with 236 additions and 5 deletions.
239 changes: 235 additions & 4 deletions datajoint_plus/table.py
@@ -1,11 +1,24 @@
"""
Extensions of DataJoint Table
"""
import inspect
import numpy as np
import uuid
import collections
import pandas
import itertools
from pathlib import Path

from .logging import getLogger

import datajoint as dj

from datajoint.errors import (
DuplicateError,
DataJointError,
UnknownAttributeError
)
from datajoint.expression import QueryExpression
from datajoint import blob
from datajoint_plus.hash import generate_table_id

from .utils import classproperty, goto
@@ -51,7 +64,7 @@ def goto(self, table_id=None, full_table_name=None, tid_attr=None, ftn_attr=None
if len(self)==1:
try:
table_id = self.fetch1('table_id')
-except:
+except DataJointError:
pass

if table_id is None and tid_attr is not None:
@@ -69,8 +82,8 @@ def goto(self, table_id=None, full_table_name=None, tid_attr=None, ftn_attr=None
if tid_attr is None:
tid_attr = 'table_id'
full_table_name = (ftn_lookup & {tid_attr: table_id}).fetch1('full_table_name')
-except Exception as e:
-raise Exception('If return_free_table=True and full_table_name is not provided, table_id and ftn_lookup must be provided.')
+except DataJointError:
+raise DataJointError('If return_free_table=True and full_table_name is not provided, table_id and ftn_lookup must be provided.')
return FreeTable(self.connection, full_table_name)

else:
@@ -86,6 +99,224 @@ def _table_log(self):
def table_id(cls):
return generate_table_id(cls.full_table_name)

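As a quick usage sketch of the machinery above (hedged: Mouse and TableIdLookup are hypothetical names, and goto's full signature is truncated in this diff; only goto, table_id, and generate_table_id appear in the commit):

# Hypothetical sketch -- not part of this commit.
tid = Mouse.table_id  # classproperty; equivalent to generate_table_id(Mouse.full_table_name)
# goto resolves a stored table_id back to a table; with return_free_table=True and a
# lookup relation mapping table_id -> full_table_name, it returns a dj.FreeTable:
free = Mouse().goto(table_id=tid, ftn_lookup=TableIdLookup, return_free_table=True)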
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, allow_direct_insert=None):
"""
Insert a collection of rows.
:param rows: Either a pandas.DataFrame, a query expression with the same heading as table self,
or an iterable of rows, where each row is a numpy record, a dict-like object, or a sequence.
:param replace: If True, replaces the existing tuple.
:param skip_duplicates: If True, silently skip duplicate inserts.
:param ignore_extra_fields: If False, fields that are not in the heading raise an error.
:param allow_direct_insert: applies only to auto-populated tables.
If False (default), inserts are allowed only from inside the make callback.
Example::
>>> relation.insert([
>>> dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
>>> dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
"""

if isinstance(rows, pandas.DataFrame):
# drop 'extra' synthetic index for 1-field index case -
# frames with more advanced indices should be prepared by user.
rows = rows.reset_index(
drop=len(rows.index.names) == 1 and not rows.index.names[0]
).to_records(index=False)

# prohibit direct inserts into auto-populated tables
if not allow_direct_insert and not getattr(self, '_allow_insert', True): # allow_insert is only used in AutoPopulate
raise DataJointError(
'Inserts into an auto-populated table can only be done inside its make method during a populate call.'
' To override, set keyword argument allow_direct_insert=True.')

heading = self.heading
if inspect.isclass(rows) and issubclass(rows, QueryExpression): # instantiate if a class
rows = rows()
if isinstance(rows, QueryExpression):
# insert from select
if not ignore_extra_fields:
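# next() below yields the first attribute of rows.heading missing from this table's
# heading; if every attribute is known, StopIteration is raised and swallowed, and
# the insert-from-select proceeds.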
try:
raise DataJointError(
"Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
next(name for name in rows.heading if name not in heading))
except StopIteration:
pass
fields = list(name for name in rows.heading if name in heading)
query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format(
command='REPLACE' if replace else 'INSERT',
fields='`' + '`,`'.join(fields) + '`',
table=self.full_table_name,
select=rows.make_sql(select_fields=fields),
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`'.format(
table=self.full_table_name, pk=self.primary_key[0])
if skip_duplicates else ''))
self.connection.query(query)
return

if heading.attributes is None:
logger.warning('Could not access table {table}'.format(table=self.full_table_name))
return

field_list = None # ensures that all rows have the same attributes in the same order as the first row.

def make_row_to_insert(row):
"""
:param row: A tuple to insert
:return: a dict with fields 'names', 'placeholders', 'values'
"""
def make_placeholder(name, value):
"""
For a given attribute `name` with `value`, return its processed value or value placeholder
as a string to be included in the query and the value, if any, to be submitted for
processing by mysql API.
:param name: name of attribute to be inserted
:param value: value of attribute to be inserted
"""
if ignore_extra_fields and name not in heading:
return None
attr = heading[name]
if attr.adapter:
value = attr.adapter.put(value)
if value is None or (attr.numeric and (value == '' or np.isnan(float(value)))):
# set default value
placeholder, value = 'DEFAULT', None
else: # not NULL
placeholder = '%s'
if attr.uuid:
if not isinstance(value, uuid.UUID):
try:
value = uuid.UUID(value)
except (AttributeError, ValueError):
raise DataJointError(
'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None
value = value.bytes
elif attr.is_blob:
value = blob.pack(value)
value = self.external[attr.store].put(value).bytes if attr.is_external else value
elif attr.is_attachment:
attachment_path = Path(value)
if attr.is_external:
# value is hash of contents
value = self.external[attr.store].upload_attachment(attachment_path).bytes
else:
# value is filename + contents
value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
elif attr.is_filepath:
value = self.external[attr.store].upload_filepath(value).bytes
elif attr.numeric:
value = str(int(value) if isinstance(value, bool) else value)
return name, placeholder, value

def check_fields(fields):
"""
Validates that all items in `fields` are valid attributes in the heading
:param fields: field names of a tuple
"""
if field_list is None:
if not ignore_extra_fields:
for field in fields:
if field not in heading:
raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
elif set(field_list) != set(fields).intersection(heading.names):
raise DataJointError('Attempt to insert rows with different fields')

if isinstance(row, np.void): # np.array
check_fields(row.dtype.fields)
attributes = [make_placeholder(name, row[name])
for name in heading if name in row.dtype.fields]
elif isinstance(row, collections.abc.Mapping): # dict-based
check_fields(row)
attributes = [make_placeholder(name, row[name]) for name in heading if name in row]
else: # positional
try:
if len(row) != len(heading):
raise DataJointError(
'Invalid insert argument. Incorrect number of attributes: '
'{given} given; {expected} expected'.format(
given=len(row), expected=len(heading)))
except TypeError:
raise DataJointError('Datatype %s cannot be inserted' % type(row))
else:
attributes = [make_placeholder(name, value) for name, value in zip(heading, row)]
if ignore_extra_fields:
attributes = [a for a in attributes if a is not None]

assert len(attributes), 'Empty tuple'
row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
nonlocal field_list
if field_list is None:
# first row sets the composition of the field list
field_list = row_to_insert['names']
else:
# reorder attributes in row_to_insert to match field_list
order = list(row_to_insert['names'].index(field) for field in field_list)
row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)

return row_to_insert

rows = list(make_row_to_insert(row) for row in rows)
if rows:
try:
query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
command='REPLACE' if replace else 'INSERT',
destination=self.from_clause,
fields='`,`'.join(field_list),
placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows),
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`'.format(pk=self.primary_key[0])
if skip_duplicates else ''))
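# Note: `pk`=`pk` in ON DUPLICATE KEY UPDATE is a deliberate no-op assignment,
# turning duplicate-key errors into silent skips when skip_duplicates=True.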
self.connection.query(query, args=list(
itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows)))
except UnknownAttributeError as err:
raise err.suggest('To ignore extra fields in insert, set ignore_extra_fields=True') from None
except DuplicateError as err:
raise err.suggest('To ignore duplicate entries in insert, set skip_duplicates=True') from None

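For orientation, a minimal usage sketch of the insert paths implemented above (hedged: Subject and StagingSubject are hypothetical tables with identical headings; the keyword arguments are the ones documented in the docstring):

# Hypothetical sketch -- not part of this commit.
Subject.insert(
    [dict(subject_id=7, species="mouse"), dict(subject_id=8, species="rat")],
    skip_duplicates=True)  # row-wise path; duplicate keys are skipped silently

import pandas as pd
Subject.insert(pd.DataFrame({"subject_id": [9], "species": ["mouse"]}))  # DataFrame path

Subject.insert(StagingSubject, replace=True)  # insert-from-select path; REPLACE overwrites on key collision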
def _update(self, attrname, value=None):
"""
Updates a field in an existing tuple. This is not a datajoyous operation and should not be used
routinely. Relational databases maintain referential integrity at the level of a tuple; therefore,
the UPDATE operator can violate referential integrity. The datajoyous way to update information is
to delete the entire tuple and insert the entire updated tuple.
Safety constraints:
1. self must be restricted to exactly one tuple
2. the update attribute must not be in the primary key
Example:
>>> (v2p.Mice() & key)._update('mouse_dob', '2011-01-01')
>>> (v2p.Mice() & key)._update('lens')  # set the value to NULL
"""
if len(self) != 1:
raise DataJointError('Update is only allowed on one tuple at a time')
if attrname not in self.heading:
raise DataJointError('Invalid attribute name')
if attrname in self.heading.primary_key:
raise DataJointError('Cannot update a key value.')

attr = self.heading[attrname]

if attr.is_blob:
value = blob.pack(value)
placeholder = '%s'
elif attr.numeric:
if value is None or np.isnan(float(value)): # nans are turned into NULLs
placeholder = 'NULL'
value = None
else:
placeholder = '%s'
value = str(int(value) if isinstance(value, bool) else value)
else:
placeholder = '%s' if value is not None else 'NULL'
command = "UPDATE {full_table_name} SET `{attrname}`={placeholder} {where_clause}".format(
full_table_name=self.from_clause,
attrname=attrname,
placeholder=placeholder,
where_clause=self.where_clause)
self.connection.query(command, args=(value, ) if value is not None else ())
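A hedged sketch of the narrow _update pattern (reusing the docstring's hypothetical v2p.Mice table; the restriction must select exactly one tuple and the attribute must not be in the primary key):

(v2p.Mice() & key)._update('mouse_dob', '2011-01-01')  # overwrite a non-key attribute in place
(v2p.Mice() & key)._update('lens')  # omitting the value sets the field to NULL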


class FreeTable(Table, dj.FreeTable):
"""
2 changes: 1 addition & 1 deletion datajoint_plus/version.py
@@ -1 +1 @@
__version__ = "0.1.7"
__version__ = "0.1.8"
