From 11290b7b9c288c247a2a802c706592524685cdd9 Mon Sep 17 00:00:00 2001 From: spapa013 Date: Tue, 9 Jan 2024 09:54:35 -0600 Subject: [PATCH 1/2] move and update datajoint insertion code to djp to remove np.bool error --- datajoint_plus/table.py | 239 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 235 insertions(+), 4 deletions(-) diff --git a/datajoint_plus/table.py b/datajoint_plus/table.py index 3469030..60ee7b7 100644 --- a/datajoint_plus/table.py +++ b/datajoint_plus/table.py @@ -1,11 +1,24 @@ """ Extensions of DataJoint Table """ +import inspect +import numpy as np +import uuid +import collections +import pandas +import itertools +from pathlib import Path from .logging import getLogger import datajoint as dj - +from datajoint.errors import ( + DuplicateError, + DataJointError, + UnknownAttributeError +) +from datajoint.expression import QueryExpression +from datajoint import blob from datajoint_plus.hash import generate_table_id from .utils import classproperty, goto @@ -51,7 +64,7 @@ def goto(self, table_id=None, full_table_name=None, tid_attr=None, ftn_attr=None if len(self)==1: try: table_id = self.fetch1('table_id') - except: + except DataJointError: pass if table_id is None and tid_attr is not None: @@ -69,8 +82,8 @@ def goto(self, table_id=None, full_table_name=None, tid_attr=None, ftn_attr=None if tid_attr is None: tid_attr = 'table_id' full_table_name = (ftn_lookup & {tid_attr: table_id}).fetch1('full_table_name') - except Exception as e: - raise Exception('If return_free_table=True and full_table_name is not provided, table_id and ftn_lookup must be provided.') + except DataJointError: + raise DataJointError('If return_free_table=True and full_table_name is not provided, table_id and ftn_lookup must be provided.') return FreeTable(self.connection, full_table_name) else: @@ -86,6 +99,224 @@ def _table_log(self): def table_id(cls): return generate_table_id(cls.full_table_name) + def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, allow_direct_insert=None): + """ + Insert a collection of rows. + + :param rows: An iterable where an element is a numpy record, a dict-like object, a pandas.DataFrame, a sequence, + or a query expression with the same heading as table self. + :param replace: If True, replaces the existing tuple. + :param skip_duplicates: If True, silently skip duplicate inserts. + :param ignore_extra_fields: If False, fields that are not in the heading raise error. + :param allow_direct_insert: applies only in auto-populated tables. + If False (default), insert are allowed only from inside the make callback. + + Example:: + >>> relation.insert([ + >>> dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"), + >>> dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")]) + """ + + if isinstance(rows, pandas.DataFrame): + # drop 'extra' synthetic index for 1-field index case - + # frames with more advanced indices should be prepared by user. + rows = rows.reset_index( + drop=len(rows.index.names) == 1 and not rows.index.names[0] + ).to_records(index=False) + + # prohibit direct inserts into auto-populated tables + if not allow_direct_insert and not getattr(self, '_allow_insert', True): # allow_insert is only used in AutoPopulate + raise DataJointError( + 'Inserts into an auto-populated table can only done inside its make method during a populate call.' + ' To override, set keyword argument allow_direct_insert=True.') + + heading = self.heading + if inspect.isclass(rows) and issubclass(rows, QueryExpression): # instantiate if a class + rows = rows() + if isinstance(rows, QueryExpression): + # insert from select + if not ignore_extra_fields: + try: + raise DataJointError( + "Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." % + next(name for name in rows.heading if name not in heading)) + except StopIteration: + pass + fields = list(name for name in rows.heading if name in heading) + query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format( + command='REPLACE' if replace else 'INSERT', + fields='`' + '`,`'.join(fields) + '`', + table=self.full_table_name, + select=rows.make_sql(select_fields=fields), + duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`'.format( + table=self.full_table_name, pk=self.primary_key[0]) + if skip_duplicates else '')) + self.connection.query(query) + return + + if heading.attributes is None: + logger.warning('Could not access table {table}'.format(table=self.full_table_name)) + return + + field_list = None # ensures that all rows have the same attributes in the same order as the first row. + + def make_row_to_insert(row): + """ + :param row: A tuple to insert + :return: a dict with fields 'names', 'placeholders', 'values' + """ + def make_placeholder(name, value): + """ + For a given attribute `name` with `value`, return its processed value or value placeholder + as a string to be included in the query and the value, if any, to be submitted for + processing by mysql API. + :param name: name of attribute to be inserted + :param value: value of attribute to be inserted + """ + if ignore_extra_fields and name not in heading: + return None + attr = heading[name] + if attr.adapter: + value = attr.adapter.put(value) + if value is None or (attr.numeric and (value == '' or np.isnan(float(value)))): + # set default value + placeholder, value = 'DEFAULT', None + else: # not NULL + placeholder = '%s' + if attr.uuid: + if not isinstance(value, uuid.UUID): + try: + value = uuid.UUID(value) + except (AttributeError, ValueError): + raise DataJointError( + 'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None + value = value.bytes + elif attr.is_blob: + value = blob.pack(value) + value = self.external[attr.store].put(value).bytes if attr.is_external else value + elif attr.is_attachment: + attachment_path = Path(value) + if attr.is_external: + # value is hash of contents + value = self.external[attr.store].upload_attachment(attachment_path).bytes + else: + # value is filename + contents + value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes() + elif attr.is_filepath: + value = self.external[attr.store].upload_filepath(value).bytes + elif attr.numeric: + value = str(int(value) if isinstance(value, bool) else value) + return name, placeholder, value + + def check_fields(fields): + """ + Validates that all items in `fields` are valid attributes in the heading + :param fields: field names of a tuple + """ + if field_list is None: + if not ignore_extra_fields: + for field in fields: + if field not in heading: + raise KeyError(u'`{0:s}` is not in the table heading'.format(field)) + elif set(field_list) != set(fields).intersection(heading.names): + raise DataJointError('Attempt to insert rows with different fields') + + if isinstance(row, np.void): # np.array + check_fields(row.dtype.fields) + attributes = [make_placeholder(name, row[name]) + for name in heading if name in row.dtype.fields] + elif isinstance(row, collections.abc.Mapping): # dict-based + check_fields(row) + attributes = [make_placeholder(name, row[name]) for name in heading if name in row] + else: # positional + try: + if len(row) != len(heading): + raise DataJointError( + 'Invalid insert argument. Incorrect number of attributes: ' + '{given} given; {expected} expected'.format( + given=len(row), expected=len(heading))) + except TypeError: + raise DataJointError('Datatype %s cannot be inserted' % type(row)) + else: + attributes = [make_placeholder(name, value) for name, value in zip(heading, row)] + if ignore_extra_fields: + attributes = [a for a in attributes if a is not None] + + assert len(attributes), 'Empty tuple' + row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes))) + nonlocal field_list + if field_list is None: + # first row sets the composition of the field list + field_list = row_to_insert['names'] + else: + # reorder attributes in row_to_insert to match field_list + order = list(row_to_insert['names'].index(field) for field in field_list) + row_to_insert['names'] = list(row_to_insert['names'][i] for i in order) + row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order) + row_to_insert['values'] = list(row_to_insert['values'][i] for i in order) + + return row_to_insert + + rows = list(make_row_to_insert(row) for row in rows) + if rows: + try: + query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format( + command='REPLACE' if replace else 'INSERT', + destination=self.from_clause, + fields='`,`'.join(field_list), + placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows), + duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`'.format(pk=self.primary_key[0]) + if skip_duplicates else '')) + self.connection.query(query, args=list( + itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows))) + except UnknownAttributeError as err: + raise err.suggest('To ignore extra fields in insert, set ignore_extra_fields=True') from None + except DuplicateError as err: + raise err.suggest('To ignore duplicate entries in insert, set skip_duplicates=True') from None + + def _update(self, attrname, value=None): + """ + Updates a field in an existing tuple. This is not a datajoyous operation and should not be used + routinely. Relational database maintain referential integrity on the level of a tuple. Therefore, + the UPDATE operator can violate referential integrity. The datajoyous way to update information is + to delete the entire tuple and insert the entire update tuple. + + Safety constraints: + 1. self must be restricted to exactly one tuple + 2. the update attribute must not be in primary key + + Example: + >>> (v2p.Mice() & key).update('mouse_dob', '2011-01-01') + >>> (v2p.Mice() & key).update( 'lens') # set the value to NULL + """ + if len(self) != 1: + raise DataJointError('Update is only allowed on one tuple at a time') + if attrname not in self.heading: + raise DataJointError('Invalid attribute name') + if attrname in self.heading.primary_key: + raise DataJointError('Cannot update a key value.') + + attr = self.heading[attrname] + + if attr.is_blob: + value = blob.pack(value) + placeholder = '%s' + elif attr.numeric: + if value is None or np.isnan(float(value)): # nans are turned into NULLs + placeholder = 'NULL' + value = None + else: + placeholder = '%s' + value = str(int(value) if isinstance(value, bool) else value) + else: + placeholder = '%s' if value is not None else 'NULL' + command = "UPDATE {full_table_name} SET `{attrname}`={placeholder} {where_clause}".format( + full_table_name=self.from_clause, + attrname=attrname, + placeholder=placeholder, + where_clause=self.where_clause) + self.connection.query(command, args=(value, ) if value is not None else ()) + class FreeTable(Table, dj.FreeTable): """ From 41495c211db1b23df5bc70e978903e775f35b9bb Mon Sep 17 00:00:00 2001 From: spapa013 Date: Tue, 9 Jan 2024 09:54:43 -0600 Subject: [PATCH 2/2] update version --- datajoint_plus/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint_plus/version.py b/datajoint_plus/version.py index f1380ee..9cb17e7 100644 --- a/datajoint_plus/version.py +++ b/datajoint_plus/version.py @@ -1 +1 @@ -__version__ = "0.1.7" +__version__ = "0.1.8"