Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Fully merge Pydantic Field with SQLAlchemy Column constructor #436

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 229 additions & 90 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
import warnings
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand Down Expand Up @@ -38,8 +39,10 @@
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import FetchedValue, MetaData, SchemaItem
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from sqlalchemy.sql.type_api import TypeEngine

from .sql.sqltypes import GUID, AutoString

Expand All @@ -57,35 +60,94 @@ def __dataclass_transform__(


class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", Undefined)
foreign_key = kwargs.pop("foreign_key", Undefined)
unique = kwargs.pop("unique", False)
index = kwargs.pop("index", Undefined)
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
if sa_column is not Undefined:
if sa_column_args is not Undefined:
raise RuntimeError(
"Passing sa_column_args is not supported when "
"also passing a sa_column"
)
if sa_column_kwargs is not Undefined:
raise RuntimeError(
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
self.foreign_key = foreign_key
self.unique = unique
self.index = index
self.sa_column = sa_column
self.sa_column_args = sa_column_args
self.sa_column_kwargs = sa_column_kwargs

# In addition to the `PydanticFieldInfo` slots, set slots corresponding to parameters for the SQLAlchemy
# [Column](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column),
# along with any custom additions:
__slots__ = (
"name",
"type_",
"args",
"autoincrement",
# `default` omitted because that slot is defined on the base class
"doc",
"key",
"index",
"info",
"nullable",
"onupdate",
"primary_key",
"server_default",
"server_onupdate",
"quote",
"unique",
"system",
"comment",
"foreign_key", # custom parameter for easier foreign key setting
# For backwards compatibility: (!?)
"sa_column",
"sa_column_args",
"sa_column_kwargs",
)

# Defined here for static type checkers:
name: Union[str, UndefinedType]
type_: Union[TypeEngine, UndefinedType] # type: ignore[type-arg]
args: Sequence[SchemaItem]
autoincrement: Union[bool, str]
doc: Optional[str]
key: Union[str, UndefinedType]
index: Optional[bool]
info: Union[Dict[str, Any], UndefinedType]
nullable: Union[bool, UndefinedType]
onupdate: Any
primary_key: bool
server_default: Union[FetchedValue, str, TextClause, None]
server_onupdate: Optional[FetchedValue]
quote: Union[bool, None, UndefinedType]
unique: Optional[bool]
system: bool
comment: Optional[str]

foreign_key: Optional[str]
daniil-berg marked this conversation as resolved.
Show resolved Hide resolved

sa_column: Union[Column, UndefinedType] # type: ignore[type-arg]
sa_column_args: Sequence[Any]
sa_column_kwargs: Mapping[str, Any]

def __init__(self, **kwargs: Any) -> None:
# Split off all keyword-arguments corresponding to our new additional attributes:
new_kwargs = {param: kwargs.pop(param, Undefined) for param in self.__slots__}
# Pass the rest of the keyword-arguments to the Pydantic `FieldInfo.__init__`:
super().__init__(**kwargs)
# Set the other keyword-arguments as instance attributes:
for param, value in new_kwargs.items():
setattr(self, param, value)

def get_defined_column_kwargs(self) -> Dict[str, Any]:
"""
Returns a dictionary of keyword arguments for the SQLAlchemy `Column.__init__` method
derived from the corresponding attributes of the `FieldInfo` instance,
omitting all those that have been left undefined.
"""
special = {
"args",
"foreign_key",
"sa_column",
"sa_column_args",
"sa_column_kwargs",
}
kwargs = {}
for key in self.__slots__:
if key in special:
continue
value = getattr(self, key, Undefined)
if value is not Undefined:
kwargs[key] = value
default = get_field_info_default(self)
if default is not Undefined:
kwargs["default"] = default
return kwargs


class RelationshipInfo(Representation):
Expand Down Expand Up @@ -117,8 +179,9 @@ def __init__(


def Field(
default: Any = Undefined,
*,
*args: SchemaItem, # positional arguments for SQLAlchemy `Column.__init__`
default: Any = Undefined, # meaningful for both Pydantic and SQLAlchemy
# The following are specific to Pydantic:
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand All @@ -141,19 +204,78 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
# The following are specific to SQLAlchemy:
name: Optional[str] = None,
type_: Union[TypeEngine, UndefinedType] = Undefined, # type: ignore[type-arg]
autoincrement: Union[bool, str] = "auto",
doc: Optional[str] = None,
key: Union[str, UndefinedType] = Undefined, # `Column` default is `name`
index: Optional[bool] = None,
info: Union[Dict[str, Any], UndefinedType] = Undefined, # `Column` default is `{}`
nullable: Union[
bool, UndefinedType
] = Undefined, # `Column` default depends on `primary_key`
onupdate: Any = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
server_default: Union[FetchedValue, str, TextClause, None] = None,
server_onupdate: Optional[FetchedValue] = None,
quote: Union[
bool, None, UndefinedType
] = Undefined, # `Column` default not (fully) defined
unique: Optional[bool] = None,
system: bool = False,
comment: Optional[str] = None,
foreign_key: Optional[str] = None,
daniil-berg marked this conversation as resolved.
Show resolved Hide resolved
# For backwards compatibility: (!?)
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore[type-arg]
sa_column_args: Sequence[Any] = (),
sa_column_kwargs: Optional[Mapping[str, Any]] = None,
# Extra:
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
) -> FieldInfo:
"""
Constructor for explicitly defining the attributes of a model field.

The resulting field information is used both for Pydantic model validation **and** for SQLAlchemy column definition.

The following parameters are passed to initialize the Pydantic `FieldInfo`
(see [`Field` docs](https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)):
`default`, `default_factory`, `alias`, `title`, `description`, `exclude`, `include`, `const`, `gt`, `ge`,
`lt`, `le`, `multiple_of`, `min_items`, `max_items`, `min_length`, `max_length`, `allow_mutation`, `regex`.

These parameters are passed to initialize the SQLAlchemy
[`Column`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column):
`*args`, `name`, `type_`, `autoincrement`, `doc`, `key`, `index`, `info`, `nullable`, `onupdate`, `primary_key`,
`server_default`, `server_onupdate`, `quote`, `unique`, `system`, `comment`.

If provided, the `default_factory` argument is passed as `default` to the `Column` constructor;
otherwise, if the `default` argument is provided, it is passed to the `Column` constructor.

Note:
The SQLAlchemy `Column` default for `type_` is actually `None`, but it makes more sense to leave it undefined,
unless an argument is passed explicitly. If someone explicitly wants to pass `None` to set the `NullType` for
whatever reason, they will be able to do so.
(see [`type_`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.type_))
"""
current_schema_extra = schema_extra or {}
# For backwards compatibility: (!?)
if sa_column is not Undefined:
warnings.warn(
"Specifying `sa_column` overrides all other column arguments",
DeprecationWarning,
)
if sa_column_args != ():
warnings.warn(
"Instead of `sa_column_args` use positional arguments",
DeprecationWarning,
)
if sa_column_kwargs is not None:
warnings.warn(
"`sa_column_kwargs` takes precedence over other keyword-arguments",
DeprecationWarning,
)
field_info = FieldInfo(
daniil-berg marked this conversation as resolved.
Show resolved Hide resolved
default,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
Expand All @@ -172,14 +294,27 @@ def Field(
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
name=name,
type_=type_,
args=args,
autoincrement=autoincrement,
doc=doc,
key=key,
index=index,
info=info,
nullable=nullable,
onupdate=onupdate,
primary_key=primary_key,
foreign_key=foreign_key,
server_default=server_default,
server_onupdate=server_onupdate,
quote=quote,
unique=unique,
nullable=nullable,
index=index,
system=system,
comment=comment,
foreign_key=foreign_key,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
sa_column_kwargs=sa_column_kwargs or {},
**current_schema_extra,
)
field_info._validate()
Expand Down Expand Up @@ -414,47 +549,53 @@ def get_sqlalchemy_type(field: ModelField) -> Any:
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")


def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False)
index = getattr(field.field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and _is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable")
if field_nullable != Undefined:
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
unique = getattr(field.field_info, "unique", False)
if foreign_key:
args.append(ForeignKey(foreign_key))
def get_field_info_default(info: PydanticFieldInfo) -> Any:
"""Returns the `default_factory` if set, otherwise the `default` value."""
return info.default_factory if info.default_factory is not None else info.default


def get_column_from_pydantic_field(field: ModelField) -> Column: # type: ignore[type-arg]
"""Returns an SQLAlchemy `Column` instance derived from a regular Pydantic `ModelField`."""
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
"unique": unique,
"type_": get_sqlalchemy_type(field),
"nullable": _is_field_noneable(field),
}
sa_default = Undefined
if field.field_info.default_factory:
sa_default = field.field_info.default_factory
elif field.field_info.default is not Undefined:
sa_default = field.field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) # type: ignore
default = get_field_info_default(field.field_info)
if default is not Undefined:
kwargs["default"] = default
return Column(**kwargs)


def get_column_from_field(field: ModelField) -> Column: # type: ignore[type-arg]
"""Returns an SQLAlchemy `Column` instance derived from an SQLModel field."""
if not isinstance(
field.field_info, FieldInfo
): # must be regular `PydanticFieldInfo`
return get_column_from_pydantic_field(field)
# We are dealing with the customized `FieldInfo` object:
field_info: FieldInfo = field.field_info
# The `sa_column` argument trumps everything: (for backwards compatibility)
if isinstance(field_info.sa_column, Column):
return field_info.sa_column
args: List[SchemaItem] = []
kwargs = field_info.get_defined_column_kwargs()
# Only if no column type was explicitly defined, do we derive it here:
kwargs.setdefault("type_", get_sqlalchemy_type(field))
# Only if nullability was not defined, do we infer it here:
kwargs.setdefault(
"nullable", not kwargs.get("primary_key", False) and _is_field_noneable(field)
)
# If a foreign key reference was explicitly named, construct the schema item here,
# and make it the first positional argument for the `Column`:
if field_info.foreign_key:
args.append(ForeignKey(field_info.foreign_key))
daniil-berg marked this conversation as resolved.
Show resolved Hide resolved
# All other positional column arguments are appended:
args.extend(field_info.args)
# Append `sa_column_args`: (for backwards compatibility)
args.extend(field_info.sa_column_args)
# Finally, let the `sa_column_kwargs` take precedence: (for backwards compatibility)
kwargs.update(field_info.sa_column_kwargs)
return Column(*args, **kwargs)


class_registry = weakref.WeakValueDictionary() # type: ignore
Expand Down Expand Up @@ -647,9 +788,7 @@ def __tablename__(cls) -> str:


def _is_field_noneable(field: ModelField) -> bool:
if not field.required:
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
return field.allow_none and (
field.shape != SHAPE_SINGLETON or not field.sub_fields
)
return False
if field.required:
return False
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
return field.allow_none and (field.shape != SHAPE_SINGLETON or not field.sub_fields)
Loading