Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upsert #156

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,17 @@ Keyword Argument Description
parameters.

``ignore_conflicts`` Specify True to ignore unique constraint or exclusion
constraint violation errors. The default is False.
constraint violation errors. The default is False. This
is depreciated in favor of `on_conflict={'action': 'ignore'}`.

``on_conflict`` Specifies how PostgreSQL handles conflicts. For example,
`on_conflict={'action': 'ignore'}` will ignore any
conflicts. If setting `'action'` to `'update'`, you
must also specify `'target'` (the source of the
constraint: either a model field name, a constraint name,
or a list of model field names) as well as `'columns'`
(a list of model fields to update). The default is None,
which will raise conflict errors if they occur.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
Expand Down
87 changes: 78 additions & 9 deletions postgres_copy/copy_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from collections import OrderedDict
from io import TextIOWrapper
import warnings

from django.contrib.humanize.templatetags.humanize import intcomma
from django.core.exceptions import FieldDoesNotExist
Expand All @@ -34,6 +35,7 @@ def __init__(
force_null=None,
encoding=None,
ignore_conflicts=False,
on_conflict={},
static_mapping=None,
temp_table_name=None,
):
Expand All @@ -58,8 +60,9 @@ def __init__(
self.force_not_null = force_not_null
self.force_null = force_null
self.encoding = encoding
self.supports_ignore_conflicts = True
self.supports_on_conflict = True
self.ignore_conflicts = ignore_conflicts
self.on_conflict = on_conflict
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
Expand All @@ -77,12 +80,18 @@ def __init__(
if self.conn.vendor != "postgresql":
raise TypeError("Only PostgreSQL backends supported")

# Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported
self.supports_ignore_conflicts = self.is_postgresql_9_5()
if self.ignore_conflicts and not self.supports_ignore_conflicts:
raise NotSupportedError(
"This database backend does not support ignoring conflicts."
# Check if it is PSQL 9.5 or greater, which determines if on_conflict is supported
self.supports_on_conflict = self.is_postgresql_9_5()
if self.ignore_conflicts:
self.on_conflict = {
'action': 'ignore',
}
warnings.warn(
"The `ignore_conflicts` kwarg has been replaced with "
"on_conflict={'action': 'ignore'}."
)
if self.on_conflict and not self.supports_on_conflict:
raise NotSupportedError('This database backend does not support conflict logic.')

# Pull the CSV headers
self.headers = self.get_headers()
Expand Down Expand Up @@ -319,10 +328,70 @@ def insert_suffix(self):
"""
Preps the suffix to the insert query.
"""
if self.ignore_conflicts:
if self.on_conflict:
try:
action = self.on_conflict['action']
except KeyError:
raise ValueError("Must specify an `action` when passing `on_conflict`.")
if action == 'ignore':
target, action = "", "DO NOTHING"
elif action == 'update':
try:
target = self.on_conflict['target']
except KeyError:
raise ValueError("Must specify `target` when action == 'update'.")
try:
columns = self.on_conflict['columns']
except KeyError:
raise ValueError("Must specify `columns` when action == 'update'.")

# As recommended in PostgreSQL's INSERT documentation, we use "index inference"
# rather than naming a constraint directly. Currently, if an `include` param
# is provided to a django.models.Constraint, Django creates a UNIQUE INDEX instead
# of a CONSTRAINT, another reason to use "index inference" by just specifying columns.
constraints = {c.name: c for c in self.model._meta.constraints}
if isinstance(target, str):
if constraints.get(target):
# Make sure to use db column names
target = [
self.get_field(field_name).column
for field_name in constraints.get(target).fields
]
else:
target = [target]
elif not isinstance(target, list):
raise ValueError("`target` must be a string or a list.")
target = "({0})".format(', '.join(target))

# Convert to db_column names
db_columns = [self.model._meta.get_field(col).column for col in columns]

# Get update_values from the `excluded` table
update_values = ', '.join([
"{0} = excluded.{0}".format(db_col)
for db_col in db_columns
])

# Only update the row if the values are different
model_table = self.model._meta.db_table
new_values = ', '.join([
model_table + '.' + db_col
for db_col in db_columns
])
old_values = ', '.join([
"excluded.{0}".format(db_col)
for db_col in db_columns
])
action = "DO UPDATE SET {0} WHERE ({1}) IS DISTINCT FROM ({2})".format(
update_values,
new_values,
old_values,
)
else:
raise ValueError("Action must be one of 'ignore' or 'update'.")
return """
ON CONFLICT DO NOTHING;
"""
ON CONFLICT {0} {1};
""".format(target, action)
else:
return ";"

Expand Down
51 changes: 46 additions & 5 deletions postgres_copy/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,20 @@ def drop_constraints(self):

# Remove any field constraints
for field in self.constrained_fields:
logger.debug(f"Dropping constraints from {field}")

logger.debug("Dropping field constraint from {}".format(field))

field_copy = field.__copy__()
field_copy.db_constraint = False
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, "alter_field", args)

# Remove remaining constraints
for constraint in getattr(self.model._meta, 'constraints', []):
logger.debug("Dropping constraint '{}'".format(constraint.name))
args = (self.model, constraint)
self.edit_schema(schema_editor, 'remove_constraint', args)

def drop_indexes(self):
"""
Drop indexes on the model and its fields.
Expand All @@ -88,12 +96,20 @@ def drop_indexes(self):

# Remove any field indexes
for field in self.indexed_fields:
logger.debug(f"Dropping index from {field}")

logger.debug("Dropping field index from {}".format(field))

field_copy = field.__copy__()
field_copy.db_index = False
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, "alter_field", args)

# Remove remaining indexes
for index in getattr(self.model._meta, 'indexes', []):
logger.debug("Dropping index '{}'".format(index.name))
args = (self.model, index)
self.edit_schema(schema_editor, 'remove_index', args)

def restore_constraints(self):
"""
Restore constraints on the model and its fields.
Expand All @@ -111,14 +127,22 @@ def restore_constraints(self):
args = (self.model, (), self.model._meta.unique_together)
self.edit_schema(schema_editor, "alter_unique_together", args)

# Add any constraints to the fields
# Add any field constraints
for field in self.constrained_fields:
logger.debug(f"Adding constraints to {field}")

logger.debug("Adding field constraint to {}".format(field))

field_copy = field.__copy__()
field_copy.db_constraint = False
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, "alter_field", args)

# Add remaining constraints
for constraint in getattr(self.model._meta, 'constraints', []):
logger.debug("Adding constraint '{}'".format(constraint.name))
args = (self.model, constraint)
self.edit_schema(schema_editor, 'add_constraint', args)

def restore_indexes(self):
"""
Restore indexes on the model and its fields.
Expand All @@ -138,12 +162,20 @@ def restore_indexes(self):

# Add any indexes to the fields
for field in self.indexed_fields:
logger.debug(f"Restoring index to {field}")

logger.debug("Restoring field index to {}".format(field))

field_copy = field.__copy__()
field_copy.db_index = False
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, "alter_field", args)

# Add remaining indexes
for index in getattr(self.model._meta, 'indexes', []):
logger.debug("Adding index '{}'".format(index.name))
args = (self.model, index)
self.edit_schema(schema_editor, 'add_index', args)


class CopyQuerySet(ConstraintQuerySet):
"""
Expand Down Expand Up @@ -178,6 +210,15 @@ def from_csv(
"drop_constraints=False and drop_indexes=False."
)

# NOTE: See GH Issue #117
# We could remove this block if drop_constraints' default was False
if kwargs.get('on_conflict'):
if kwargs['on_conflict'].get('target'):
if kwargs['on_conflict']['target'] in [c.name for c in self.model._meta.constraints]:
drop_constraints = False
elif kwargs['on_conflict'].get('action') == 'ignore':
drop_constraints = False

mapping = CopyMapping(self.model, csv_path, mapping, **kwargs)

if drop_constraints:
Expand Down
31 changes: 30 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,35 @@ class SecondaryMockObject(models.Model):
objects = CopyManager()


class UniqueMockObject(models.Model):
class UniqueFieldConstraintMockObject(models.Model):
name = models.CharField(max_length=500, unique=True)
objects = CopyManager()


class UniqueModelConstraintMockObject(models.Model):
name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column='num')
objects = CopyManager()

class Meta:
constraints = [
models.UniqueConstraint(
name='constraint',
fields=['name'],
),
]


class UniqueModelConstraintAsIndexMockObject(models.Model):
name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column='num')
objects = CopyManager()

class Meta:
constraints = [
models.UniqueConstraint(
name='constraint_as_index',
fields=['name'],
include=['number'], # Converts Constraint to Index
),
]
92 changes: 87 additions & 5 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
MockObject,
OverloadMockObject,
SecondaryMockObject,
UniqueMockObject,
UniqueFieldConstraintMockObject,
UniqueModelConstraintMockObject,
UniqueModelConstraintAsIndexMockObject,
)


Expand Down Expand Up @@ -516,11 +518,91 @@ def test_encoding_save(self, _):

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_ignore_conflicts(self, _):
UniqueMockObject.objects.from_csv(
self.name_path, dict(name="NAME"), ignore_conflicts=True
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
ignore_conflicts=True
)
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
ignore_conflicts=True
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_ignore(self, _):
UniqueModelConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={'action': 'ignore'},
)
UniqueModelConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={'action': 'ignore'},
)
UniqueMockObject.objects.from_csv(
self.name_path, dict(name="NAME"), ignore_conflicts=True

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_field_update(self, _):
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
on_conflict={
'action': 'update',
'target': 'name',
'columns': ['name'],
},
)
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
on_conflict={
'action': 'update',
'target': 'name',
'columns': ['name'],
},
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_constraint_update(self, _):
UniqueModelConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={
'action': 'update',
'target': 'constraint',
'columns': ['name', 'number'],
},
)
UniqueModelConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={
'action': 'update',
'target': 'constraint',
'columns': ['name', 'number'],
},
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_constraint_as_index_update(self, _):
UniqueModelConstraintAsIndexMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={
'action': 'update',
'target': 'constraint_as_index',
'columns': ['name', 'number'],
},
)
UniqueModelConstraintAsIndexMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER'),
on_conflict={
'action': 'update',
'target': 'constraint_as_index',
'columns': ['name', 'number'],
},
)

@mock.patch("django.db.connection.validate_no_atomic_block")
Expand Down