Skip to content
This repository has been archived by the owner on Mar 6, 2023. It is now read-only.

_clone_copy_fk added to allow copying many-to-one relationships. #7

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions django_cloneable/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ def _clone_attrs(self, duplicate, attrs, exclude=None):
for attname, value in attrs.items():
setattr(duplicate, attname, value)

def _clone_copy_fk(self, duplicate, exclude=None):
exclude = exclude or []
foreign_keys = {}

for field in self.instance._meta.related_objects:
# Skip this field.
if field.name in exclude:
continue

# Check for one to many:
if field.one_to_many:
f_name = '%s_set' % field.name

# Collect the objects which contain ForeignKey pointing to the source object
fks_to_copy = list(getattr(self.instance, f_name).all())

for fk in fks_to_copy:
# Empty primary key
fk.pk = None

# Iterate fields in the classes which contain ForeignKey pointing to the source object.
# If the field has the same object as our source, we should rewrite it's value to point
# to our newly created duplicated record.
for fk_field in fk._meta.fields:
if fk_field.related_model:
if duplicate._meta.object_name == fk_field.related_model._meta.object_name:
setattr(fk, fk_field.name, duplicate)

try:
# Use fk.__class__ here to avoid hard-coding the class name
foreign_keys[fk.__class__].append(fk)
except KeyError:
foreign_keys[fk.__class__] = [fk]

# Insert the new records in the database
for cls, list_of_fks in foreign_keys.items():
cls.objects.bulk_create(list_of_fks)

def _clone_copy_m2m(self, duplicate, exclude=None):
exclude = exclude or []
# copy.copy loses all ManyToMany relations.
Expand Down Expand Up @@ -142,6 +180,11 @@ def clone(self, attrs=None, commit=True, m2m_clone_reverse=True,
clone_attrs = getattr(self.instance, '_clone_attrs', self._clone_attrs)
clone_attrs(duplicate, attrs, exclude=exclude)

def clone_fk():
clone_copy_fk = getattr(self.instance, '_clone_copy_fk',
self._clone_copy_fk)
clone_copy_fk(duplicate, exclude=exclude)

def clone_m2m(clone_reverse=m2m_clone_reverse):
clone_copy_m2m = getattr(self.instance, '_clone_copy_m2m',
self._clone_copy_m2m)
Expand All @@ -155,8 +198,10 @@ def clone_m2m(clone_reverse=m2m_clone_reverse):

if commit:
duplicate.save(force_insert=True)
clone_fk()
clone_m2m()
else:
duplicate.clone_fk = clone_fk
duplicate.clone_m2m = clone_m2m
return duplicate

Expand Down Expand Up @@ -203,6 +248,9 @@ def _clone_attrs(self, duplicate, attrs, exclude=None):
return self._clone_helper._clone_attrs(duplicate, attrs,
exclude=exclude)

def _clone_copy_fk(self, duplicate, exclude=None):
return self._clone_helper._clone_copy_fk(duplicate, exclude=exclude)

def _clone_copy_m2m(self, duplicate, exclude=None):
return self._clone_helper._clone_copy_m2m(duplicate, exclude=exclude)

Expand All @@ -219,6 +267,9 @@ def clone(self, attrs=None, commit=True, m2m_clone_reverse=True,
self._clone_prepare(duplicate, exclude=exclude)
self._clone_attrs(duplicate, attrs, exclude=exclude)

def clone_fk():
self._clone_copy_fk(duplicate, exclude=exclude)

def clone_m2m(clone_reverse=m2m_clone_reverse):
self._clone_copy_m2m(duplicate, exclude=exclude)
if clone_reverse:
Expand All @@ -227,8 +278,10 @@ def clone_m2m(clone_reverse=m2m_clone_reverse):

if commit:
duplicate.save(force_insert=True)
clone_fk()
clone_m2m()
else:
duplicate.clone_fk = clone_fk
duplicate.clone_m2m = clone_m2m

return duplicate
Expand Down