diff --git a/drf_util/dsl_serializers/__init__.py b/drf_util/dsl_serializers/__init__.py new file mode 100644 index 0000000..7de7132 --- /dev/null +++ b/drf_util/dsl_serializers/__init__.py @@ -0,0 +1,8 @@ +from pkg_resources import require + +from .fields import * +from .serializers import * + + +require('elasticsearch') +require('elasticsearch_dsl') diff --git a/drf_util/dsl_serializers/fields.py b/drf_util/dsl_serializers/fields.py new file mode 100644 index 0000000..64a3e8c --- /dev/null +++ b/drf_util/dsl_serializers/fields.py @@ -0,0 +1,234 @@ +from abc import abstractmethod +from typing import Any, Type, Iterable + +from elasticsearch_dsl.query import Query, Term, Terms +from elasticsearch_dsl.search import Search + +from rest_framework.fields import ( + Field, + BooleanField, + NullBooleanField, + CharField, + EmailField, + RegexField, + SlugField, + URLField, + UUIDField, + IPAddressField, + IntegerField, + FloatField, + DecimalField, + DateTimeField, + DateField, + TimeField, + DurationField, + ChoiceField, + MultipleChoiceField, + ListField, +) + + +__all__ = [ + 'DslField', 'DslSortField', 'DslQueryField', 'DslSourceField', + 'CharSortField', 'ChoiceSortField', 'MultipleChoiceSortField', 'CharListSortField', + 'BooleanQueryField', 'NullBooleanQueryField', + 'CharQueryField', 'EmailQueryField', 'RegexQueryField', 'SlugQueryField', + 'URLQueryField', 'UUIDQueryField', 'IPAddressQueryField', + 'IntegerQueryField', 'FloatQueryField', 'DecimalQueryField', + 'DateTimeQueryField', 'DateQueryField', 'TimeQueryField', 'DurationQueryField', + 'CharListQueryField', 'IntegerListQueryField', 'FloatListQueryField', + 'CharSourceField', 'ChoiceSourceField', 'MultipleChoiceSourceField', 'CharListSourceField', +] + + +class DslField(Field): + @abstractmethod + def get_search(self, value: Any, search: Search = None, *args, **kwargs) -> Search: + pass + + +class DslSortField(DslField): + def get_keys(self, value, *args, **kwargs): + return value + + def get_search(self, value, search=None, *args, **kwargs): + fields = self.get_keys(value, *args, **kwargs) + if not isinstance(fields, Iterable): + raise ValueError(f"Method 'get_keys' expected list of string but got {fields}") + + search = search or Search() + if not isinstance(search, Search): + raise ValueError(f"Argument 'search' expected instance of Search but got {search}") + + return search.sort(*list(fields)) + + +class DslQueryField(DslField): + doc_field: str + dsl_query: Type[Query] + + def __init__(self, doc_field=None, dsl_query=None, *args, **kwargs): + super(DslQueryField, self).__init__(*args, **kwargs) + self.doc_field = doc_field or self.doc_field + self.dsl_query = dsl_query or self.dsl_query + + doc_field = doc_field or self.doc_field + if not isinstance(doc_field, str): + raise ValueError(f'doc_field should be instance of str but got {type(doc_field)}') + + dsl_query = dsl_query or self.dsl_query + if not issubclass(dsl_query, Query): + raise ValueError(f'dsl_query should be type of Query but got {dsl_query}') + + def get_query(self, value, *args, **kwargs): + query = self.dsl_query(**{self.doc_field: value}) + return query + + def get_search(self, value, search=None, *args, **kwargs): + query = self.get_query(value, *args, **kwargs) + if not isinstance(query, Query): + raise ValueError(f"Method 'get_query' expected instance of Query but got {query}") + + search = search or Search() + if not isinstance(search, Search): + raise ValueError(f"Argument 'search' expected instance of Search but got {search}") + + return search.query(query) + + +class DslSourceField(DslField): + def get_fields(self, value, *args, **kwargs): + return value + + def get_search(self, value, search=None, *args, **kwargs): + fields = self.get_fields(value, *args, **kwargs) + if not isinstance(fields, Iterable): + raise ValueError(f"Method 'get_keys' expected list of string but got {fields}") + + search = search or Search() + if not isinstance(search, Search): + raise ValueError(f"Argument 'search' expected instance of Search but got {search}") + + return search.source(fields=list(fields)) + + +# DslSortField + +class CharSortField(CharField, DslSortField): + def get_keys(self, value, *args, **kwargs): + return [value] + + +class ChoiceSortField(ChoiceField, DslSortField): + def get_keys(self, value, *args, **kwargs): + return [value] + + +class MultipleChoiceSortField(MultipleChoiceField, DslSortField): + pass + + +class CharListSortField(ListField, DslSortField): + child = CharField() + + +# DslQueryField + + +class BooleanQueryField(BooleanField, DslQueryField): + dsl_query = Term + + +class NullBooleanQueryField(NullBooleanField, DslQueryField): + dsl_query = Term + + +class CharQueryField(CharField, DslQueryField): + dsl_query = Term + + +class EmailQueryField(EmailField, DslQueryField): + dsl_query = Term + + +class RegexQueryField(RegexField, DslQueryField): + dsl_query = Term + + +class SlugQueryField(SlugField, DslQueryField): + dsl_query = Term + + +class URLQueryField(URLField, DslQueryField): + dsl_query = Term + + +class UUIDQueryField(UUIDField, DslQueryField): + dsl_query = Term + + +class IPAddressQueryField(IPAddressField, DslQueryField): + dsl_query = Term + + +class IntegerQueryField(IntegerField, DslQueryField): + dsl_query = Term + + +class FloatQueryField(FloatField, DslQueryField): + dsl_query = Term + + +class DecimalQueryField(DecimalField, DslQueryField): + dsl_query = Term + + +class DateTimeQueryField(DateTimeField, DslQueryField): + dsl_query = Term + + +class DateQueryField(DateField, DslQueryField): + dsl_query = Term + + +class TimeQueryField(TimeField, DslQueryField): + dsl_query = Term + + +class DurationQueryField(DurationField, DslQueryField): + dsl_query = Term + + +class CharListQueryField(ListField, DslQueryField): + dsl_query = Terms + child = CharField() + + +class IntegerListQueryField(ListField, DslQueryField): + dsl_query = Terms + child = IntegerField() + + +class FloatListQueryField(ListField, DslQueryField): + dsl_query = Terms + child = FloatField() + + +# DslSourceField + +class CharSourceField(CharField, DslSourceField): + def get_fields(self, value, *args, **kwargs): + return [value] + + +class ChoiceSourceField(ChoiceField, DslSourceField): + def get_fields(self, value, *args, **kwargs): + return [value] + + +class MultipleChoiceSourceField(MultipleChoiceField, DslSourceField): + pass + + +class CharListSourceField(ListField, DslSourceField): + child = CharField() diff --git a/drf_util/dsl_serializers/serializers.py b/drf_util/dsl_serializers/serializers.py new file mode 100644 index 0000000..af75d62 --- /dev/null +++ b/drf_util/dsl_serializers/serializers.py @@ -0,0 +1,106 @@ +from typing import Type +from datetime import datetime, time + +from rest_framework.fields import empty, IntegerField, FloatField, DateTimeField, DateField, TimeField +from rest_framework.serializers import Serializer + +from elasticsearch_dsl.query import Query, Range +from elasticsearch_dsl.search import Search + +from .fields import DslField, DslQueryField + + +__all__ = [ + 'DslSerializer', 'DslQuerySerializer', + 'IntegerRangeQueryField', 'FloatRangeQueryField', + 'DateTimeRangeQueryField', 'DateRangeQueryField', +] + + +class DslSerializer(Serializer): + def create(self, validated_data): + raise RuntimeError() + + def update(self, instance, validated_data): + raise RuntimeError() + + def get_search(self, search=None, *args, **kwargs) -> Search: + values = self.validated_data + fields = self.fields + + search = search or Search() + if not isinstance(search, Search): + raise ValueError(f"Argument 'search' expected instance of Search but got {search}") + + for name, field in fields.items(): + if not isinstance(field, DslField): + continue + + value = field.get_value(values) + if value is empty: + continue + + search = field.get_search(value, search, *args, **kwargs) + + return search + + +class DslQuerySerializer(Serializer, DslQueryField): + dsl_query: Type[Query] + doc_field: str + + def create(self, validated_data): + raise RuntimeError() + + def update(self, instance, validated_data): + raise RuntimeError() + + +class IntegerRangeQueryField(DslQuerySerializer): + dsl_query = Range + + gt = IntegerField(required=False) + lt = IntegerField(required=False) + gte = IntegerField(required=False) + lte = IntegerField(required=False) + + +class FloatRangeQueryField(DslQuerySerializer): + dsl_query = Range + + gt = FloatField(required=False) + lt = FloatField(required=False) + gte = FloatField(required=False) + lte = FloatField(required=False) + + +class DateTimeRangeQueryField(DslQuerySerializer): + dsl_query = Range + + gt = DateTimeField(required=False) + lt = DateTimeField(required=False) + gte = DateTimeField(required=False) + lte = DateTimeField(required=False) + + +class DateRangeQueryField(DslQuerySerializer): + dsl_query = Range + + gt = DateField(required=False) + lt = DateField(required=False) + gte = DateField(required=False) + lte = DateField(required=False) + + def to_internal_value(self, data): + value = super(DateRangeQueryField, self).to_internal_value(data) + + if 'gt' in value: + value['gt'] = datetime.combine(date=value['gt'], time=time.min) + if 'lt' in value: + value['lt'] = datetime.combine(date=value['lt'], time=time.max) + if 'gte' in value: + value['gte'] = datetime.combine(date=value['gte'], time=time.min) + if 'lte' in value: + value['lte'] = datetime.combine(date=value['lte'], time=time.max) + + return value diff --git a/setup.py b/setup.py index ec06570..32ced71 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='drf_util', - version='0.0.28', + version='0.0.29', description='Django Rest Framework Utils', long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_dsl_serializers.py b/tests/test_dsl_serializers.py new file mode 100644 index 0000000..b7eaa7d --- /dev/null +++ b/tests/test_dsl_serializers.py @@ -0,0 +1,280 @@ +from datetime import datetime, date, time + +from elasticsearch_dsl.query import Match +from django.test import TestCase + +from drf_util.dsl_serializers import ( + DslSerializer, + + CharSortField, + ChoiceSortField, + MultipleChoiceSortField, + CharListSortField, + BooleanQueryField, + NullBooleanQueryField, + CharQueryField, + EmailQueryField, + RegexQueryField, + SlugQueryField, + URLQueryField, + UUIDQueryField, + IPAddressQueryField, + IntegerQueryField, + FloatQueryField, + DecimalQueryField, + DateTimeQueryField, + DateQueryField, + TimeQueryField, + CharListQueryField, + IntegerListQueryField, + FloatListQueryField, + CharSourceField, + ChoiceSourceField, + MultipleChoiceSourceField, + CharListSourceField, +) + + +class TestDslSerializer(TestCase): + def test_serializer_a(self): + class TestSerializer(DslSerializer): + source = CharSourceField() + query = CharQueryField(doc_field='doc_field') + sort = CharSortField() + + data = { + 'source': 'doc_field', + 'query': 'value', + 'sort': 'doc_field' + } + + serializer = TestSerializer(data=data) + serializer.is_valid() + search = serializer.get_search() + + expected = { + 'query': {'term': {'doc_field': 'doc_field_value'}}, + 'sort': ['doc_field'], + '_source': ['doc_field'] + } + + self.assertEqual(expected, search.to_dict()) + + def test_serializer_b(self): + class TestSerializer(DslSerializer): + source = CharListSourceField() + query = CharListQueryField(doc_field='doc_field') + sort = CharListSortField() + + data = { + 'source': ['doc_field_1', 'doc_field_2'], + 'query': ['value_1', 'value_2'], + 'sort': ['doc_field_1', '-doc_field_2'] + } + + serializer = TestSerializer(data=data) + serializer.is_valid() + search = serializer.get_search() + + expected = { + 'query': {'terms': {'doc_field': ['value_1', 'value_2']}}, + 'sort': ['doc_field_1', {'doc_field_2': {'order': 'desc'}}], + '_source': ['doc_field_1', 'doc_field_2'] + } + + self.assertEqual(expected, search.to_dict()) + + def test_serializer_c(self): + class TestSerializer(DslSerializer): + source = CharSourceField(required=False) + query = CharQueryField(doc_field='doc_field', dsl_query=Match) + sort = CharSortField(default='-date_created') + + data = { + 'query': 'value', + } + + serializer = TestSerializer(data=data) + serializer.is_valid() + search = serializer.get_search() + + expected = { + 'query': {'match': {'doc_field': 'value'}}, + 'sort': [{'date_created': {'order': 'desc'}}] + } + + self.assertEqual(expected, search.to_dict()) + + +class TestDslSortFields(TestCase): + def test_char_sort_field_asc(self): + search = CharSortField().get_search('doc_field') + expected = {'sort': ['doc_field']} + self.assertEqual(expected, search.to_dict()) + + def test_char_sort_field_desc(self): + search = CharSortField().get_search('-doc_field') + expected = {'sort': [{'doc_field': {'order': 'desc'}}]} + self.assertEqual(expected, search.to_dict()) + + def test_choice_sort_field_asc(self): + search = ChoiceSortField(choices=[]).get_search('doc_field') + expected = {'sort': ['doc_field']} + self.assertEqual(expected, search.to_dict()) + + def test_choice_sort_field_desc(self): + search = ChoiceSortField(choices=[]).get_search('-doc_field') + expected = {'sort': [{'doc_field': {'order': 'desc'}}]} + self.assertEqual(expected, search.to_dict()) + + def test_multiple_choice_sort_field_asc(self): + search = MultipleChoiceSortField(choices=[]).get_search(['doc_field_1', 'doc_field_2']) + expected = {'sort': ['doc_field_1', 'doc_field_2']} + self.assertEqual(expected, search.to_dict()) + + def test_multiple_choice_sort_field_desc(self): + search = MultipleChoiceSortField(choices=[]).get_search(['-doc_field_1', '-doc_field_2']) + expected = {'sort': [{'doc_field_1': {'order': 'desc'}}, {'doc_field_2': {'order': 'desc'}}]} + self.assertEqual(expected, search.to_dict()) + + def test_char_list_sort_field_asc(self): + search = CharListSortField().get_search(['doc_field_1', 'doc_field_2']) + expected = {'sort': ['doc_field_1', 'doc_field_2']} + self.assertEqual(expected, search.to_dict()) + + def test_char_list_sort_field_desc(self): + search = CharListSortField().get_search(['-doc_field_1', '-doc_field_2']) + expected = {'sort': [{'doc_field_1': {'order': 'desc'}}, {'doc_field_2': {'order': 'desc'}}]} + self.assertEqual(expected, search.to_dict()) + + +class TestDslQueryFields(TestCase): + def test_boolean_query_field(self): + value = True + search = BooleanQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_null_boolean_query_field(self): + value = None + search = NullBooleanQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_char_query_field(self): + value = 'string' + search = CharQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_email_query_field(self): + value = 'drf_serializers@drf_util.ebs' + search = EmailQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_regex_query_field(self): + value = 'string' + search = RegexQueryField(doc_field='doc_field', regex='*.').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_slug_query_field(self): + value = 'slug' + search = SlugQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_url_query_field(self): + value = 'http://drf_util.ebs/drf_serializer' + search = URLQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_uuid_query_field(self): + value = 'sdf87f5ad8f76fd87' + search = UUIDQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_ip_address_query_field(self): + value = '127.0.0.1' + search = IPAddressQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_integer_query_field(self): + value = 111 + search = IntegerQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_float_query_field(self): + value = 1.1 + search = FloatQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_decimal_query_field(self): + value = 1000.0001 + search = DecimalQueryField(doc_field='doc_field', decimal_places=4, max_digits=8).get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_date_time_query_field(self): + value = datetime.now() + search = DateTimeQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_date_query_field(self): + value = date.today() + search = DateQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_time_query_field(self): + value = time(10, 10, 10, 1000) + search = TimeQueryField(doc_field='doc_field').get_search(value) + expected = {'term': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_char_list_query_field(self): + value = ['a', 'b', 'c'] + search = CharListQueryField(doc_field='doc_field').get_search(value) + expected = {'terms': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_integer_list_query_field(self): + value = [1, 2, 3] + search = IntegerListQueryField(doc_field='doc_field').get_search(value) + expected = {'terms': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + def test_float_list_query_field(self): + value = [1.1, 2.2, 3.3] + search = FloatListQueryField(doc_field='doc_field').get_search(value) + expected = {'terms': {'doc_field': value}} + self.assertEqual(expected, search.to_dict()) + + +class TestDslSourceFields(TestCase): + def test_char_source_field(self): + search = CharSourceField().get_search('doc_field') + expected = {'_source': ['doc_field']} + self.assertEqual(expected, search.to_dict()) + + def test_choice_source_field(self): + search = ChoiceSourceField(choices=[]).get_search('doc_field') + expected = {'_source': ['doc_field']} + self.assertEqual(expected, search.to_dict()) + + def test_multiple_choice_source_field(self): + search = MultipleChoiceSourceField(choices=[]).get_search(['doc_field_1', 'doc_field_2']) + expected = {'_source': ['doc_field_1', 'doc_field_2']} + self.assertEqual(expected, search.to_dict()) + + def test_char_list_source_field(self): + search = CharListSourceField().get_search(['doc_field_1', 'doc_field_2']) + expected = {'_source': ['doc_field_1', 'doc_field_2']} + self.assertEqual(expected, search.to_dict())