Skip to content

Commit

Permalink
Add unit testing for base fields
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Jul 31, 2023
1 parent 5b4d4af commit 589e685
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ build-docs: ## Runs the local docs

.PHONY: test
test: ## Runs the tests
scripts/check && pytest $(TESTONLY) --disable-pytest-warnings -s -vv && scripts/clean
scripts/check && pytest $(TESTONLY) --disable-pytest-warnings -s -vv

.PHONY: requirements
requirements: ## Install requirements for development
Expand Down
13 changes: 11 additions & 2 deletions edgy/core/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
self.multiple_of: Optional[Union[int, float, decimal.Decimal]] = kwargs.pop(
"multiple_of", None
)

# Constraints
self.contraints: Constraint = kwargs.pop("constraints", None)

Expand All @@ -70,7 +71,6 @@ def __init__(
super().__init__(
default=default,
alias=self.alias,
required=self.null,
title=title,
description=description,
min_length=self.min_length,
Expand All @@ -82,10 +82,19 @@ def __init__(
multiple_of=self.multiple_of,
max_digits=self.max_digits,
decimal_places=self.decimal_places,
regex=self.regex,
pattern=self.regex,
**kwargs,
)

def is_required(self) -> bool:
"""Check if the argument is required.
Returns:
`True` if the argument is required, `False` otherwise.
"""
required = False if self.null else True
return required

def get_alias(self) -> str:
"""
Used to translate the model column names into database column tables.
Expand Down
24 changes: 4 additions & 20 deletions edgy/core/db/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class FieldFactory:
"""The base for all model fields to be used with EdgeDB"""

_bases = (BaseField,)
_property: bool = False
_link: bool = False
_type: Any = None

def __new__(cls, *args: Any, **kwargs: Any) -> BaseField: # type: ignore
Expand All @@ -40,13 +38,10 @@ def __new__(cls, *args: Any, **kwargs: Any) -> BaseField: # type: ignore
server_default = kwargs.pop("server_default", None)
server_onupdate = kwargs.pop("server_onupdate", None)
field_type = cls._type
is_property = cls._property
is_link = cls._link

namespace = dict(
__type__=field_type,
__property__=is_property,
__link__=is_link,
annotation=field_type,
name=name,
primary_key=primary_key,
default=default,
Expand Down Expand Up @@ -88,7 +83,6 @@ class CharField(FieldFactory, str):
"""String field representation that constructs the Field class and populates the values"""

_type = str
_property: bool = True

def __new__( # type: ignore
cls,
Expand Down Expand Up @@ -120,7 +114,6 @@ class TextField(FieldFactory, str):
"""String representation of a text field which means no max_length required"""

_type = str
_property: bool = True

def __new__(cls, **kwargs: Any) -> BaseField: # type: ignore
kwargs = {
Expand Down Expand Up @@ -159,7 +152,6 @@ class IntegerField(Number, int):
"""

_type = int
_property: bool = True

def __new__( # type: ignore
cls,
Expand Down Expand Up @@ -190,7 +182,6 @@ class FloatField(Number, float):
"""Representation of a int32 and int64"""

_type = float
_property: bool = True

def __new__( # type: ignore
cls,
Expand Down Expand Up @@ -231,7 +222,6 @@ def get_column_type(cls, **kwargs: Any) -> Any:

class DecimalField(Number, decimal.Decimal):
_type = decimal.Decimal
_property: bool = True

def __new__( # type: ignore
cls,
Expand Down Expand Up @@ -278,7 +268,6 @@ class BooleanField(FieldFactory, int):
"""Representation of a boolean"""

_type = bool
_property: bool = True

def __new__( # type: ignore
cls,
Expand Down Expand Up @@ -321,7 +310,6 @@ class DateTimeField(AutoNowMixin, datetime.datetime):
"""Representation of a datetime field"""

_type = datetime.datetime
_property: bool = True

def __new__( # type: ignore
cls,
Expand All @@ -348,7 +336,6 @@ class DateField(AutoNowMixin, datetime.date):
"""Representation of a date field"""

_type = datetime.date
_property: bool = True

def __new__( # type: ignore
cls,
Expand All @@ -375,7 +362,6 @@ class TimeField(FieldFactory, datetime.time):
"""Representation of a time field"""

_type = datetime.time
_property: bool = True

def __new__(cls, **kwargs: Any) -> BaseField: # type: ignore
kwargs = {
Expand All @@ -393,7 +379,6 @@ class JSONField(FieldFactory, pydantic.Json): # type: ignore
"""Representation of a JSONField"""

_type = pydantic.Json
_property: bool = True

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
Expand All @@ -404,7 +389,6 @@ class BinaryField(FieldFactory, bytes):
"""Representation of a binary"""

_type = bytes
_property: bool = True

def __new__(cls, *, max_length: Optional[int] = 0, **kwargs: Any) -> BaseField: # type: ignore
kwargs = {
Expand All @@ -428,7 +412,6 @@ class UUIDField(FieldFactory, uuid.UUID):
"""Representation of a uuid"""

_type = uuid.UUID
_property: bool = True

def __new__(cls, **kwargs: Any) -> BaseField: # type: ignore
kwargs = {
Expand All @@ -447,10 +430,11 @@ class ChoiceField(FieldFactory):
"""Representation of an Enum"""

_type = enum.Enum
_property: bool = True

def __new__( # type: ignore
cls, choices: Sequence[Union[Tuple[str, str], Tuple[str, int]]], **kwargs: Any
cls,
choices: Optional[Sequence[Union[Tuple[str, str], Tuple[str, int]]]] = None,
**kwargs: Any,
) -> BaseField:
kwargs = {
**kwargs,
Expand Down
82 changes: 82 additions & 0 deletions tests/core/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime
import decimal
import enum
import uuid

import pydantic
import pytest
import sqlalchemy

Expand All @@ -26,6 +28,11 @@
from edgy.exceptions import FieldDefinitionError


class Choices(str, enum.Enum):
ACTIVE = "active"
INACTIVE = "inactive"


def test_column_type():
assert isinstance(CharField.get_column_type(), sqlalchemy.String)
assert isinstance(TextField.get_column_type(), sqlalchemy.String)
Expand All @@ -44,13 +51,83 @@ def test_column_type():
assert isinstance(ChoiceField.get_column_type(), sqlalchemy.Enum)


@pytest.mark.parametrize(
"field,annotation",
[
(CharField(max_length=255), str),
(TextField(), str),
(FloatField(), float),
(BooleanField(), bool),
(DateTimeField(auto_now=True), datetime.datetime),
(DateField(auto_now=True), datetime.date),
(TimeField(), datetime.time),
(JSONField(), pydantic.Json),
(BinaryField(max_length=255), bytes),
(IntegerField(), int),
(BigIntegerField(), int),
(SmallIntegerField(), int),
(DecimalField(max_digits=20, precision=2), decimal.Decimal),
(ChoiceField(choices=Choices), enum.Enum),
],
)
def test_field_annotation(field, annotation):
assert field.annotation == annotation


@pytest.mark.parametrize(
"field,is_required",
[
(CharField(max_length=255, null=False), True),
(TextField(null=False), True),
(FloatField(null=False), True),
(DateTimeField(null=False), True),
(DateField(null=False), True),
(TimeField(null=False), True),
(JSONField(null=False), True),
(BinaryField(max_length=255, null=False), True),
(IntegerField(null=False), True),
(BigIntegerField(null=False), True),
(SmallIntegerField(null=False), True),
(DecimalField(max_digits=20, precision=2, null=False), True),
(ChoiceField(choices=Choices, null=False), True),
],
)
def test_field_required(field, is_required):
assert field.is_required() == is_required
assert field.null is False


@pytest.mark.parametrize(
"field,is_required",
[
(CharField(max_length=255, null=True), False),
(TextField(null=True), False),
(FloatField(null=True), False),
(DateTimeField(null=True), False),
(DateField(null=True), False),
(TimeField(null=True), False),
(JSONField(null=True), False),
(BinaryField(max_length=255, null=True), False),
(IntegerField(null=True), False),
(BigIntegerField(null=True), False),
(SmallIntegerField(null=True), False),
(DecimalField(max_digits=20, precision=2, null=True), False),
(ChoiceField(choices=Choices, null=True), False),
],
)
def test_field_is_not_required(field, is_required):
assert field.is_required() == is_required
assert field.null is True


def test_can_create_string_field():
field = CharField(min_length=5, max_length=10, null=True)

assert isinstance(field, BaseField)
assert field.min_length == 5
assert field.max_length == 10
assert field.null is True
assert not field.is_required()


def test_raises_field_definition_error_on_string_creation():
Expand Down Expand Up @@ -189,3 +266,8 @@ class StatusChoice(str, enum.Enum):

assert isinstance(field, BaseField)
assert len(field.choices) == 2


def test_raise_exception_choice_field():
with pytest.raises(FieldDefinitionError):
ChoiceField()

0 comments on commit 589e685

Please sign in to comment.