Skip to content
This repository has been archived by the owner on Mar 21, 2019. It is now read-only.

Commit

Permalink
Merge pull request #353 from uisautomation/full-text-search
Browse files Browse the repository at this point in the history
Add full-text search for media items
  • Loading branch information
abrahammartin authored Oct 13, 2018
2 parents e7e6754 + 4476a39 commit fce9f8f
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 29 deletions.
82 changes: 82 additions & 0 deletions api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,88 @@ def test_token_auth_list(self):
for item in response_data['results']:
self.assertIn(item['id'], expected_ids)

def test_search_by_title(self):
    """An item is found by a query matching its title but not by an unrelated query."""
    media_item = mpmodels.MediaItem.objects.first()

    # Flag the item's view permission as public and persist that first.
    permission = media_item.view_permission
    permission.is_public = True
    permission.save()

    # Only the title is given searchable text in this test.
    media_item.title = 'some bananas'
    media_item.save()

    self.assert_search_result(
        media_item, positive_query='Banana', negative_query='Pineapple')

def test_search_by_description(self):
    """An item is found by a query matching its description but not by an unrelated query."""
    media_item = mpmodels.MediaItem.objects.first()

    # Flag the item's view permission as public and persist that first.
    permission = media_item.view_permission
    permission.is_public = True
    permission.save()

    # Only the description is given searchable text in this test.
    media_item.description = 'some bananas'
    media_item.save()

    self.assert_search_result(
        media_item, positive_query='Banana', negative_query='Pineapple')

def test_search_by_tags(self):
    """Items can be searched by tags."""
    item = mpmodels.MediaItem.objects.first()
    # The matching word is deliberately placed inside a multi-word tag which is not the first
    # tag in the list.
    item.tags = ['apples', 'oranges', 'top bananas']
    item.view_permission.is_public = True
    item.view_permission.save()
    item.save()
    # One check is sufficient: repeating the identical assertion only re-issues the same two
    # search requests without covering anything new.
    self.assert_search_result(item, positive_query='Banana', negative_query='Pineapple')

def test_search_ordering(self):
    """Items are sorted by relevance from search endpoint."""
    # Materialise the slice into a list. Indexing a lazy QuerySet only hands back the same model
    # instance on every access if its result cache happens to be populated already; with a list
    # the instance which is modified is guaranteed to be the instance which is saved.
    items = list(mpmodels.MediaItem.objects.all()[:2])
    # The ordering checks below are meaningless without two distinct items.
    self.assertEqual(len(items), 2)
    first, second = items

    for item in items:
        item.view_permission.is_public = True
        item.view_permission.save()

    def assert_both_found_with_top(expected_top):
        # Both items should appear in the results ...
        for item in items:
            self.assert_search_result(item, positive_query='Banana')
        # ... and the expected one should be ranked first.
        results = self.get_search_results('Banana')
        self.assertEqual(results[0]['id'], expected_top.id)

    first.title = 'banana-y bananas are completely bananas'
    first.save()
    second.title = 'some bananas'
    second.save()

    # The first item mentions the search term more often so it should be first.
    assert_both_found_with_top(first)

    # Make the second item more relevant.
    second.description = (
        'Bananas with bananas can banana the banana. Bruce Banana is not the Hulk')
    second.save()

    # Now the second item should be first.
    assert_both_found_with_top(second)

def assert_search_result(self, item, positive_query=None, negative_query=None):
    """
    Assert on the presence of *item* in search results.

    If *positive_query* is not None, *item* must be among the results of searching for it. If
    *negative_query* is not None, *item* must not be among the results of searching for it.
    Items are matched by comparing ``item.id`` with each result's ``'id'`` entry.
    """
    def found_by(query):
        # True if any result returned for the query has this item's id.
        return any(hit['id'] == item.id for hit in self.get_search_results(query))

    # Item should appear in relevant query
    if positive_query is not None:
        self.assertTrue(found_by(positive_query))

    # Item should not appear in irrelevant query
    if negative_query is not None:
        self.assertFalse(found_by(negative_query))

def get_search_results(self, query):
    """
    Return the ``results`` list from the view's response to a GET request searching for *query*.

    The query is handed to the request factory as request data rather than being concatenated
    into the URL, so characters such as ``&``, ``#`` or spaces in *query* do not corrupt the
    query string.
    """
    # NOTE(review): assumes self.factory is a Django/DRF request factory whose get() encodes the
    # data dict into the query string - confirm against the test case's setUp.
    get_request = self.factory.get('/', {'search': query})
    response_data = self.view(get_request).data
    return response_data['results']


class MediaItemViewTestCase(ViewTestCase):
def setUp(self):
Expand Down
74 changes: 45 additions & 29 deletions api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import automationlookup
from django.conf import settings
from django.contrib.postgres.search import SearchRank, SearchQuery
from django.db import models
from django.http import Http404
from django.shortcuts import redirect, render
Expand Down Expand Up @@ -37,6 +38,42 @@ class ListPagination(pagination.CursorPagination):
page_size = 50


class FullTextSearchFilter(filters.SearchFilter):
    """
    Custom filter based on :py:class:`rest_framework.filters.SearchFilter` specialised to search
    objects via a full-text search ``SearchVectorField``. Unlike the standard search filter, this
    class accepts only one search field to be set in ``search_fields``.

    The filter *always* annotates the objects with a search rank. The name is "search_rank" by
    default but can be overridden by setting ``search_rank_annotation`` on the view. If the search
    is empty then this rank will always be zero.

    The text search configuration used to parse the search terms can be chosen by setting
    ``search_config`` on the view (for example ``'english'``). It should name the same
    configuration that was used to build the search vector, otherwise the query terms may be
    normalised differently from the indexed text. If it is not set, ``None`` is passed as the
    ``config`` argument of :py:class:`django.contrib.postgres.search.SearchQuery`.
    """
    def filter_queryset(self, request, queryset, view):
        """
        Annotate *queryset* with a search rank and, if the request has search terms, restrict it
        to rows whose search field matches at least one of them.

        :raises ValueError: if the view lists more than one field in ``search_fields``.
        """
        search_terms = self.get_search_terms(request)
        search_fields = getattr(view, 'search_fields', None)
        search_rank_annotation = getattr(view, 'search_rank_annotation', 'search_rank')
        search_config = getattr(view, 'search_config', None)

        if search_fields and len(search_fields) > 1:
            raise ValueError('Can only handle a single search field')

        # If there are no search terms, shortcut the search to return the entire query set but
        # annotate it with a fake rank so that ordering by the annotation still works.
        if not search_fields or not search_terms:
            return queryset.annotate(**{
                search_rank_annotation: models.Value(0, output_field=models.FloatField())
            })

        search_field = search_fields[0]

        # Otherwise, form a query which is the logical OR of all the query terms.
        query = SearchQuery(search_terms[0], config=search_config)
        for term in search_terms[1:]:
            query = query | SearchQuery(term, config=search_config)

        return queryset.annotate(**{
            search_rank_annotation: SearchRank(models.F(search_field), query)
        }).filter(**{search_field: query})


class ViewMixinBase:
"""
A generic mixin class for API views which provides helper methods to filter querysets of
Expand Down Expand Up @@ -171,31 +208,6 @@ def get_object(self):
return self.get_profile()


class MediaItemListSearchFilter(filters.SearchFilter):
    """
    A :py:class:`rest_framework.filters.SearchFilter` variant for
    :py:class:`mediaplatform.models.MediaItem` objects. The search parameter is treated as one
    single term. When "tags" is listed in the view's ``search_fields`` attribute, items having a
    tag equal to the lower cased search term are added to the results.
    """

    def get_search_term(self, request):
        """Return the raw value of the search query parameter, or '' if it is absent."""
        params = request.query_params
        return params.get(self.search_param, '')

    def get_search_terms(self, request):
        """Return the search parameter as a one-element list; it is not split into words."""
        term = self.get_search_term(request)
        return [term]

    def filter_queryset(self, request, queryset, view):
        """Apply the standard search filter and OR in items matched by tag, if enabled."""
        matches = super().filter_queryset(request, queryset, view)

        fields = getattr(view, 'search_fields', ())
        if 'tags' not in fields:
            return matches

        # The tag lookup is made against the unfiltered queryset and combined with the
        # standard search results.
        wanted_tag = self.get_search_term(request).lower()
        return matches | queryset.filter(tags__contains=[wanted_tag])


class MediaItemListMixin(ViewMixinBase):
"""
A mixin class for DRF generic views which has all of the specialisations necessary for listing
Expand Down Expand Up @@ -251,15 +263,19 @@ def filter_playlist(self, queryset, name, value):

class MediaItemListView(MediaItemListMixin, generics.ListCreateAPIView):
"""
Endpoint to retrieve a list of media.
List and search Media items. If no other ordering is specified, results are returned in order
of decreasing search relevance (if there is any search) and then by decreasing publication
date.
"""
filter_backends = (filters.OrderingFilter, MediaItemListSearchFilter,
filter_backends = (filters.OrderingFilter, FullTextSearchFilter,
df_filters.DjangoFilterBackend)
ordering = '-publishedAt'
# The default ordering is by search rank first and then publication date. If no search is used,
# the rank is a fixed value and the publication date dominates.
ordering = ('-search_rank', '-publishedAt')
ordering_fields = ('publishedAt', 'updatedAt')
pagination_class = ListPagination
search_fields = ('title', 'description', 'tags')
search_fields = ('text_search_vector',)
serializer_class = serializers.MediaItemSerializer
filterset_class = MediaItemFilter

Expand Down
79 changes: 79 additions & 0 deletions mediaplatform/migrations/0018_add_full_text_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import django.contrib.postgres.indexes
import django.contrib.postgres.search as pgsearch
from django.db import migrations


# Raw SQL statements which create a trigger ensuring that the text search vector is updated when
# the fields it is derived from change. The statements are listed in the order in which they must
# run: the function first, then the trigger which calls it, then a back-fill of existing rows.
CREATE_TRIGGER_SQL = [
# A function intended to be run as a trigger on the mediaplatform.MediaItem table which will
# update the text search vector field to contain a concatenation of the title, description and
# tags, parsed with the "pg_catalog.english" text search configuration. All of these are
# separated by spaces. A NULL title or description is treated as the empty string.
r'''
CREATE FUNCTION mediaplatform_mediaitem_tsvectorupdate_trigger() RETURNS trigger AS $$
begin
new.text_search_vector :=
to_tsvector('pg_catalog.english', concat(
coalesce(new.title, ''),
' ',
coalesce(new.description, ''),
' ',
array_to_string(new.tags, ' ')
));
return new;
end
$$ LANGUAGE plpgsql;
''',

# A trigger on the mediaplatform.MediaItem table which updates the text search vector field if
# the title, description or tags change or if a new row is inserted. It is a BEFORE trigger so
# that the value assigned to the new row by the function is the value which gets stored.
r'''
CREATE
TRIGGER mediaplatform_mediaitem_tsvectorupdate
BEFORE
INSERT OR UPDATE OF title, description, tags
ON
mediaplatform_mediaitem
FOR EACH ROW
EXECUTE PROCEDURE mediaplatform_mediaitem_tsvectorupdate_trigger();
''',

# Perform a trivial update of the mediaplatform.MediaItem table to cause the trigger to be run
# for each row. Assigning title to itself is enough: an "UPDATE OF" trigger fires when one of
# its columns is a target of the UPDATE, whether or not the value actually changes.
r'''
UPDATE mediaplatform_mediaitem SET title=title;
''',
]

# Drop the trigger and trigger function created by CREATE_TRIGGER_SQL. The trigger is dropped
# first because it depends on the function.
DROP_TRIGGER_SQL = [
    r'''
    DROP TRIGGER mediaplatform_mediaitem_tsvectorupdate ON mediaplatform_mediaitem;
    ''',
    # The (empty) argument list is spelled out: PostgreSQL releases before 10 reject DROP FUNCTION
    # with a bare function name, which would make this migration irreversible on those servers.
    r'''
    DROP FUNCTION mediaplatform_mediaitem_tsvectorupdate_trigger();
    ''',
]


class Migration(migrations.Migration):
    """Add the ``text_search_vector`` field to ``mediaitem``, index it and install its trigger."""

    dependencies = [
        ('mediaplatform', '0017_index_media_item_updated_at_and_published_at'),
    ]

    operations = [
        # The default only applies while the column is being added (preserve_default=False).
        migrations.AddField(
            model_name='mediaitem',
            name='text_search_vector',
            field=pgsearch.SearchVectorField(default=''),
            preserve_default=False,
        ),
        migrations.AddIndex(
            model_name='mediaitem',
            index=django.contrib.postgres.indexes.GinIndex(
                name='mediaplatfo_text_se_d418e1_gin',
                fields=['text_search_vector'],
            ),
        ),
        # Forwards: create the trigger and back-fill rows. Backwards: drop the trigger again.
        migrations.RunSQL(sql=CREATE_TRIGGER_SQL, reverse_sql=DROP_TRIGGER_SQL),
    ]
6 changes: 6 additions & 0 deletions mediaplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import automationlookup
from django.conf import settings
import django.contrib.postgres.fields as pgfields
import django.contrib.postgres.indexes as pgindexes
import django.contrib.postgres.search as pgsearch
from django.db import models
from django.db.models import Q, expressions, functions
from django.db.models.signals import post_save
Expand Down Expand Up @@ -282,6 +284,7 @@ class Meta:
indexes = (
models.Index(fields=['updated_at']),
models.Index(fields=['published_at']),
pgindexes.GinIndex(fields=['text_search_vector']),
)

VIDEO = 'video'
Expand Down Expand Up @@ -350,6 +353,9 @@ class Meta:
tags = pgfields.ArrayField(models.CharField(max_length=256), default=_blank_array, blank=True,
help_text='Tags/keywords for item')

#: Full text search vector field
text_search_vector = pgsearch.SearchVectorField()

#: Creation time
created_at = models.DateTimeField(auto_now_add=True)

Expand Down

0 comments on commit fce9f8f

Please sign in to comment.