diff --git a/mygpo/administration/group.py b/mygpo/administration/group.py
index 0adaf98b6..95729a1ef 100644
--- a/mygpo/administration/group.py
+++ b/mygpo/administration/group.py
@@ -29,15 +29,18 @@ def __get_episodes(self):
def group(self, get_features):
+ """ Groups the episodes by features extracted using ``get_features``
+
+ get_features is a callable that expects an episode as parameter, and
+ returns a value representing the extracted feature(s).
+ """
episodes = self.__get_episodes()
episode_groups = defaultdict(list)
- episode_features = map(get_features, episodes.items())
-
- for features, episode_id in episode_features:
- episode = episodes[episode_id]
+ for episode in episodes.values():
+ features = get_features(episode)
episode_groups[features].append(episode)
groups = sorted(episode_groups.values(), key=_SORT_KEY)
diff --git a/mygpo/administration/templates/admin/merge-grouping.html b/mygpo/administration/templates/admin/merge-grouping.html
index 99bbcc526..dc4354884 100644
--- a/mygpo/administration/templates/admin/merge-grouping.html
+++ b/mygpo/administration/templates/admin/merge-grouping.html
@@ -47,7 +47,7 @@
{% trans "Merge Podcasts and Episodes" %}
{% for episode in episodes %}
{% if episode.podcast.get_id == podcast.get_id %}
-
+
{% episode_link episode podcast %}
{% endif %}
{% endfor %}
diff --git a/mygpo/administration/views.py b/mygpo/administration/views.py
index 2c42d83c5..27e93e257 100644
--- a/mygpo/administration/views.py
+++ b/mygpo/administration/views.py
@@ -140,11 +140,10 @@ def post(self, request):
grouper = PodcastGrouper(podcasts)
- get_features = lambda id_e: ((id_e[1].url, id_e[1].title), id_e[0])
+ get_features = lambda episode: (episode.url, episode.title)
num_groups = grouper.group(get_features)
-
except InvalidPodcast as ip:
messages.error(request,
_('No podcast with URL {url}').format(url=str(ip)))
@@ -178,10 +177,10 @@ def post(self, request):
for key, feature in request.POST.items():
m = self.RE_EPISODE.match(key)
if m:
- episode_id = m.group(1)
+ episode_id = uuid.UUID(m.group(1))
features[episode_id] = feature
- get_features = lambda id_e: (features.get(id_e[0], id_e[0]), id_e[0])
+ get_features = lambda episode: features[episode.id]
num_groups = grouper.group(get_features)
queue_id = request.POST.get('queue_id', '')
diff --git a/mygpo/maintenance/merge.py b/mygpo/maintenance/merge.py
index d8d950cef..7c435d95a 100644
--- a/mygpo/maintenance/merge.py
+++ b/mygpo/maintenance/merge.py
@@ -12,6 +12,7 @@
from mygpo.history.models import HistoryEntry, EpisodeHistoryEntry
from mygpo.publisher.models import PublishedPodcast
from mygpo.subscriptions.models import Subscription
+from . import models
import logging
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ def merge_episodes(self):
# based on https://djangosnippets.org/snippets/2283/
@transaction.atomic
-def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
+def merge_model_objects(primary_object, alias_objects, keep_old=False):
"""
Use this function to merge model objects (i.e. Users, Organizations, Polls,
etc.) and migrate all of the related fields from the alias objects to the
@@ -78,10 +79,8 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
from django.contrib.auth.models import User
primary_user = User.objects.get(email='good_email@example.com')
duplicate_user = User.objects.get(email='good_email+duplicate@example.com')
- merge_model_objects(primary_user, duplicate_user)
+ merge_model_objects(primary_user, [duplicate_user])
"""
- if not isinstance(alias_objects, list):
- alias_objects = [alias_objects]
# check that all aliases are the same class as primary one and that
# they are subclass of model
@@ -105,11 +104,6 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
for field_name, field in fields:
generic_fields.append(field)
- blank_local_fields = set(
- [field.attname for field
- in primary_object._meta.local_fields
- if getattr(primary_object, field.attname) in [None, '']])
-
# Loop through all alias objects and migrate their data to
# the primary object.
for alias_object in alias_objects:
@@ -123,8 +117,9 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
related_objects = getattr(alias_object, alias_varname)
for obj in related_objects.all():
setattr(obj, obj_varname, primary_object)
- reassigned(obj, primary_object)
- obj.save()
+ deleted = reassigned(obj, primary_object)
+ if not deleted:
+ obj.save()
# Migrate all many to many references from alias object to
# primary object.
@@ -143,8 +138,9 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
obj_varname).all()
for obj in related_many_objects.all():
getattr(obj, obj_varname).remove(alias_object)
- reassigned(obj, primary_object)
- getattr(obj, obj_varname).add(primary_object)
+ deleted = reassigned(obj, primary_object)
+ if not deleted:
+ getattr(obj, obj_varname).add(primary_object)
# Migrate all generic foreign key references from alias
# object to primary object.
@@ -156,7 +152,10 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
related = field.model.objects.filter(**filter_kwargs)
for generic_related_object in related:
setattr(generic_related_object, field.name, primary_object)
- reassigned(generic_related_object, primary_object)
+ deleted = reassigned(generic_related_object, primary_object)
+ if deleted:
+ continue
+
try:
# execute save in a savepoint, so we can resume in the
# transaction
@@ -166,20 +165,10 @@ def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
if ie.__cause__.pgcode == PG_UNIQUE_VIOLATION:
merge(generic_related_object, primary_object)
- # Try to fill all missing values in primary object by
- # values of duplicates
- filled_up = set()
- for field_name in blank_local_fields:
- val = getattr(alias_object, field_name)
- if val not in [None, '']:
- setattr(primary_object, field_name, val)
- filled_up.add(field_name)
- blank_local_fields -= filled_up
-
if not keep_old:
before_delete(alias_object, primary_object)
alias_object.delete()
- primary_object.save()
+
return primary_object
@@ -199,6 +188,17 @@ def _get_all_related_many_to_many_objects(obj):
def reassigned(obj, new):
+ """ handles changes necessary when reassigning `obj` to `new`
+
+ Some objects have a dependent object (eg URL has a Podcast or Episode.
+ During merging, the object might be assigned from to a new Episode.
+ The re-assignment requires the "scope" field to be set to the value
+ of the new episode. In some cases it might require the existing object to
+ be deleted, to preserve uniqueness.
+
+ Returns whether the object was deleted.
+ """
+
if isinstance(obj, URL):
# a URL has its parent's scope
obj.scope = new.scope
@@ -207,11 +207,27 @@ def reassigned(obj, new):
max_order = max([-1] + [u.order for u in existing_urls])
obj.order = max_order+1
+ elif isinstance(obj, Slug):
+ # a Slug has its parent's scope
+ obj.scope = new.scope
+
+ existing_slugs = new.slugs.all()
+ max_order = max([-1] + [s.order for s in existing_slugs])
+ obj.order = max_order+1
+
elif isinstance(obj, Episode):
# obj is an Episode, new is a podcast
for url in obj.urls.all():
url.scope = new.as_scope
- url.save()
+ try:
+ with transaction.atomic():
+ url.save()
+ except IntegrityError as ie:
+ if 'podcasts_url_url_scope_key' in str(ie):
+ url.delete()
+ return True
+ else:
+ raise
elif isinstance(obj, Subscription):
pass
@@ -222,10 +238,17 @@ def reassigned(obj, new):
elif isinstance(obj, HistoryEntry):
pass
+ elif isinstance(obj, models.MergeQueueEntry):
+ obj.delete()
+ return True
+
else:
raise TypeError('unknown type for reassigning: {objtype}'.format(
objtype=type(obj)))
+ # Object was not deleted
+ return False
+
def before_delete(old, new):
diff --git a/mygpo/maintenance/models.py b/mygpo/maintenance/models.py
index 877f96ed9..17c28b5a2 100644
--- a/mygpo/maintenance/models.py
+++ b/mygpo/maintenance/models.py
@@ -7,6 +7,14 @@
class MergeQueue(UUIDModel):
""" A Group of podcasts that could be merged """
+ @property
+ def podcasts(self):
+ """ Returns the podcasts of the queue, sorted by subscribers """
+ podcasts = [entry.podcast for entry in self.entries.all()]
+ podcasts = sorted(podcasts,
+ key=lambda p: p.subscribers, reverse=True)
+ return podcasts
+
class MergeQueueEntry(UUIDModel):
""" An entry in a MergeQueue """