Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.0.29 #15

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions drf_util/dsl_serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pkg_resources import require

from .fields import *
from .serializers import *


require('elasticsearch')
require('elasticsearch_dsl')
234 changes: 234 additions & 0 deletions drf_util/dsl_serializers/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from abc import abstractmethod
from typing import Any, Type, Iterable

from elasticsearch_dsl.query import Query, Term, Terms
from elasticsearch_dsl.search import Search

from rest_framework.fields import (
Field,
BooleanField,
NullBooleanField,
CharField,
EmailField,
RegexField,
SlugField,
URLField,
UUIDField,
IPAddressField,
IntegerField,
FloatField,
DecimalField,
DateTimeField,
DateField,
TimeField,
DurationField,
ChoiceField,
MultipleChoiceField,
ListField,
)


__all__ = [
'DslField', 'DslSortField', 'DslQueryField', 'DslSourceField',
'CharSortField', 'ChoiceSortField', 'MultipleChoiceSortField', 'CharListSortField',
'BooleanQueryField', 'NullBooleanQueryField',
'CharQueryField', 'EmailQueryField', 'RegexQueryField', 'SlugQueryField',
'URLQueryField', 'UUIDQueryField', 'IPAddressQueryField',
'IntegerQueryField', 'FloatQueryField', 'DecimalQueryField',
'DateTimeQueryField', 'DateQueryField', 'TimeQueryField', 'DurationQueryField',
'CharListQueryField', 'IntegerListQueryField', 'FloatListQueryField',
'CharSourceField', 'ChoiceSourceField', 'MultipleChoiceSourceField', 'CharListSourceField',
]


class DslField(Field):
@abstractmethod
def get_search(self, value: Any, search: Search = None, *args, **kwargs) -> Search:
pass


class DslSortField(DslField):
def get_keys(self, value, *args, **kwargs):
return value

def get_search(self, value, search=None, *args, **kwargs):
fields = self.get_keys(value, *args, **kwargs)
if not isinstance(fields, Iterable):
raise ValueError(f"Method 'get_keys' expected list of string but got {fields}")

search = search or Search()
if not isinstance(search, Search):
raise ValueError(f"Argument 'search' expected instance of Search but got {search}")

return search.sort(*list(fields))


class DslQueryField(DslField):
doc_field: str
dsl_query: Type[Query]

def __init__(self, doc_field=None, dsl_query=None, *args, **kwargs):
super(DslQueryField, self).__init__(*args, **kwargs)
self.doc_field = doc_field or self.doc_field
self.dsl_query = dsl_query or self.dsl_query

doc_field = doc_field or self.doc_field
if not isinstance(doc_field, str):
raise ValueError(f'doc_field should be instance of str but got {type(doc_field)}')

dsl_query = dsl_query or self.dsl_query
if not issubclass(dsl_query, Query):
raise ValueError(f'dsl_query should be type of Query but got {dsl_query}')

def get_query(self, value, *args, **kwargs):
query = self.dsl_query(**{self.doc_field: value})
return query

def get_search(self, value, search=None, *args, **kwargs):
query = self.get_query(value, *args, **kwargs)
if not isinstance(query, Query):
raise ValueError(f"Method 'get_query' expected instance of Query but got {query}")

search = search or Search()
if not isinstance(search, Search):
raise ValueError(f"Argument 'search' expected instance of Search but got {search}")

return search.query(query)


class DslSourceField(DslField):
def get_fields(self, value, *args, **kwargs):
return value

def get_search(self, value, search=None, *args, **kwargs):
fields = self.get_fields(value, *args, **kwargs)
if not isinstance(fields, Iterable):
raise ValueError(f"Method 'get_keys' expected list of string but got {fields}")

search = search or Search()
if not isinstance(search, Search):
raise ValueError(f"Argument 'search' expected instance of Search but got {search}")

return search.source(fields=list(fields))


# DslSortField

class CharSortField(CharField, DslSortField):
def get_keys(self, value, *args, **kwargs):
return [value]


class ChoiceSortField(ChoiceField, DslSortField):
def get_keys(self, value, *args, **kwargs):
return [value]


class MultipleChoiceSortField(MultipleChoiceField, DslSortField):
pass


class CharListSortField(ListField, DslSortField):
child = CharField()


# DslQueryField


class BooleanQueryField(BooleanField, DslQueryField):
dsl_query = Term


class NullBooleanQueryField(NullBooleanField, DslQueryField):
dsl_query = Term


class CharQueryField(CharField, DslQueryField):
dsl_query = Term


class EmailQueryField(EmailField, DslQueryField):
dsl_query = Term


class RegexQueryField(RegexField, DslQueryField):
dsl_query = Term


class SlugQueryField(SlugField, DslQueryField):
dsl_query = Term


class URLQueryField(URLField, DslQueryField):
dsl_query = Term


class UUIDQueryField(UUIDField, DslQueryField):
dsl_query = Term


class IPAddressQueryField(IPAddressField, DslQueryField):
dsl_query = Term


class IntegerQueryField(IntegerField, DslQueryField):
dsl_query = Term


class FloatQueryField(FloatField, DslQueryField):
dsl_query = Term


class DecimalQueryField(DecimalField, DslQueryField):
dsl_query = Term


class DateTimeQueryField(DateTimeField, DslQueryField):
dsl_query = Term


class DateQueryField(DateField, DslQueryField):
dsl_query = Term


class TimeQueryField(TimeField, DslQueryField):
dsl_query = Term


class DurationQueryField(DurationField, DslQueryField):
dsl_query = Term


class CharListQueryField(ListField, DslQueryField):
dsl_query = Terms
child = CharField()


class IntegerListQueryField(ListField, DslQueryField):
dsl_query = Terms
child = IntegerField()


class FloatListQueryField(ListField, DslQueryField):
dsl_query = Terms
child = FloatField()


# DslSourceField

class CharSourceField(CharField, DslSourceField):
def get_fields(self, value, *args, **kwargs):
return [value]


class ChoiceSourceField(ChoiceField, DslSourceField):
def get_fields(self, value, *args, **kwargs):
return [value]


class MultipleChoiceSourceField(MultipleChoiceField, DslSourceField):
pass


class CharListSourceField(ListField, DslSourceField):
child = CharField()
106 changes: 106 additions & 0 deletions drf_util/dsl_serializers/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Type
from datetime import datetime, time

from rest_framework.fields import empty, IntegerField, FloatField, DateTimeField, DateField, TimeField
from rest_framework.serializers import Serializer

from elasticsearch_dsl.query import Query, Range
from elasticsearch_dsl.search import Search

from .fields import DslField, DslQueryField


__all__ = [
'DslSerializer', 'DslQuerySerializer',
'IntegerRangeQueryField', 'FloatRangeQueryField',
'DateTimeRangeQueryField', 'DateRangeQueryField',
]


class DslSerializer(Serializer):
def create(self, validated_data):
raise RuntimeError()

def update(self, instance, validated_data):
raise RuntimeError()

def get_search(self, search=None, *args, **kwargs) -> Search:
values = self.validated_data
fields = self.fields

search = search or Search()
if not isinstance(search, Search):
raise ValueError(f"Argument 'search' expected instance of Search but got {search}")

for name, field in fields.items():
if not isinstance(field, DslField):
continue

value = field.get_value(values)
if value is empty:
continue

search = field.get_search(value, search, *args, **kwargs)

return search


class DslQuerySerializer(Serializer, DslQueryField):
dsl_query: Type[Query]
doc_field: str

def create(self, validated_data):
raise RuntimeError()

def update(self, instance, validated_data):
raise RuntimeError()


class IntegerRangeQueryField(DslQuerySerializer):
dsl_query = Range

gt = IntegerField(required=False)
lt = IntegerField(required=False)
gte = IntegerField(required=False)
lte = IntegerField(required=False)


class FloatRangeQueryField(DslQuerySerializer):
dsl_query = Range

gt = FloatField(required=False)
lt = FloatField(required=False)
gte = FloatField(required=False)
lte = FloatField(required=False)


class DateTimeRangeQueryField(DslQuerySerializer):
dsl_query = Range

gt = DateTimeField(required=False)
lt = DateTimeField(required=False)
gte = DateTimeField(required=False)
lte = DateTimeField(required=False)


class DateRangeQueryField(DslQuerySerializer):
dsl_query = Range

gt = DateField(required=False)
lt = DateField(required=False)
gte = DateField(required=False)
lte = DateField(required=False)

def to_internal_value(self, data):
value = super(DateRangeQueryField, self).to_internal_value(data)

if 'gt' in value:
value['gt'] = datetime.combine(date=value['gt'], time=time.min)
if 'lt' in value:
value['lt'] = datetime.combine(date=value['lt'], time=time.max)
if 'gte' in value:
value['gte'] = datetime.combine(date=value['gte'], time=time.min)
if 'lte' in value:
value['lte'] = datetime.combine(date=value['lte'], time=time.max)

return value
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='drf_util',
version='0.0.28',
version='0.0.29',
description='Django Rest Framework Utils',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
Loading