diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d95c498507..5420004e6a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,5 +1,6 @@ import ipaddress import uuid +import warnings import weakref from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -38,8 +39,10 @@ from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented -from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.schema import FetchedValue, MetaData, SchemaItem from sqlalchemy.sql.sqltypes import LargeBinary, Time +from sqlalchemy.sql.type_api import TypeEngine from .sql.sqltypes import GUID, AutoString @@ -57,35 +60,94 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: - primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) - unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: - raise RuntimeError( - "Passing sa_column_args is not supported when " - "also passing a sa_column" - ) - if sa_column_kwargs is not Undefined: - raise RuntimeError( - "Passing sa_column_kwargs is not supported when " - "also passing a sa_column" - ) - super().__init__(default=default, **kwargs) - self.primary_key = primary_key - self.nullable = nullable - self.foreign_key = foreign_key - self.unique = unique - self.index = index - self.sa_column = sa_column - self.sa_column_args = sa_column_args - self.sa_column_kwargs = sa_column_kwargs + + # In addition to the `PydanticFieldInfo` slots, set slots corresponding to parameters for the SQLAlchemy + # [Column](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column), + # along with any custom additions: + __slots__ = ( + "name", + "type_", + "args", + "autoincrement", + # `default` omitted because that slot is defined on the base class + "doc", + "key", + "index", + "info", + "nullable", + "onupdate", + "primary_key", + "server_default", + "server_onupdate", + "quote", + "unique", + "system", + "comment", + "foreign_key", # custom parameter for easier foreign key setting + # For backwards compatibility: (!?) + "sa_column", + "sa_column_args", + "sa_column_kwargs", + ) + + # Defined here for static type checkers: + name: Union[str, UndefinedType] + type_: Union[TypeEngine, UndefinedType] # type: ignore[type-arg] + args: Sequence[SchemaItem] + autoincrement: Union[bool, str] + doc: Optional[str] + key: Union[str, UndefinedType] + index: Optional[bool] + info: Union[Dict[str, Any], UndefinedType] + nullable: Union[bool, UndefinedType] + onupdate: Any + primary_key: bool + server_default: Union[FetchedValue, str, TextClause, None] + server_onupdate: Optional[FetchedValue] + quote: Union[bool, None, UndefinedType] + unique: Optional[bool] + system: bool + comment: Optional[str] + + foreign_key: Optional[str] + + sa_column: Union[Column, UndefinedType] # type: ignore[type-arg] + sa_column_args: Sequence[Any] + sa_column_kwargs: Mapping[str, Any] + + def __init__(self, **kwargs: Any) -> None: + # Split off all keyword-arguments corresponding to our new additional attributes: + new_kwargs = {param: kwargs.pop(param, Undefined) for param in self.__slots__} + # Pass the rest of the keyword-arguments to the Pydantic `FieldInfo.__init__`: + super().__init__(**kwargs) + # Set the other keyword-arguments as instance attributes: + for param, value in new_kwargs.items(): + setattr(self, param, value) + + def get_defined_column_kwargs(self) -> Dict[str, Any]: + """ + Returns a dictionary of keyword arguments for the SQLAlchemy `Column.__init__` method + derived from the corresponding attributes of the `FieldInfo` instance, + omitting all those that have been left undefined. + """ + special = { + "args", + "foreign_key", + "sa_column", + "sa_column_args", + "sa_column_kwargs", + } + kwargs = {} + for key in self.__slots__: + if key in special: + continue + value = getattr(self, key, Undefined) + if value is not Undefined: + kwargs[key] = value + default = get_field_info_default(self) + if default is not Undefined: + kwargs["default"] = default + return kwargs class RelationshipInfo(Representation): @@ -117,8 +179,9 @@ def __init__( def Field( - default: Any = Undefined, - *, + *args: SchemaItem, # positional arguments for SQLAlchemy `Column.__init__` + default: Any = Undefined, # meaningful for both Pydantic and SQLAlchemy + # The following are specific to Pydantic: default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, title: Optional[str] = None, @@ -141,19 +204,78 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + # The following are specific to SQLAlchemy: + name: Optional[str] = None, + type_: Union[TypeEngine, UndefinedType] = Undefined, # type: ignore[type-arg] + autoincrement: Union[bool, str] = "auto", + doc: Optional[str] = None, + key: Union[str, UndefinedType] = Undefined, # `Column` default is `name` + index: Optional[bool] = None, + info: Union[Dict[str, Any], UndefinedType] = Undefined, # `Column` default is `{}` + nullable: Union[ + bool, UndefinedType + ] = Undefined, # `Column` default depends on `primary_key` + onupdate: Any = None, primary_key: bool = False, - foreign_key: Optional[Any] = None, - unique: bool = False, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + server_default: Union[FetchedValue, str, TextClause, None] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Union[ + bool, None, UndefinedType + ] = Undefined, # `Column` default not (fully) defined + unique: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + foreign_key: Optional[str] = None, + # For backwards compatibility: (!?) + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore[type-arg] + sa_column_args: Sequence[Any] = (), + sa_column_kwargs: Optional[Mapping[str, Any]] = None, + # Extra: schema_extra: Optional[Dict[str, Any]] = None, -) -> Any: +) -> FieldInfo: + """ + Constructor for explicitly defining the attributes of a model field. + + The resulting field information is used both for Pydantic model validation **and** for SQLAlchemy column definition. + + The following parameters are passed to initialize the Pydantic `FieldInfo` + (see [`Field` docs](https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)): + `default`, `default_factory`, `alias`, `title`, `description`, `exclude`, `include`, `const`, `gt`, `ge`, + `lt`, `le`, `multiple_of`, `min_items`, `max_items`, `min_length`, `max_length`, `allow_mutation`, `regex`. + + These parameters are passed to initialize the SQLAlchemy + [`Column`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column): + `*args`, `name`, `type_`, `autoincrement`, `doc`, `key`, `index`, `info`, `nullable`, `onupdate`, `primary_key`, + `server_default`, `server_onupdate`, `quote`, `unique`, `system`, `comment`. + + If provided, the `default_factory` argument is passed as `default` to the `Column` constructor; + otherwise, if the `default` argument is provided, it is passed to the `Column` constructor. + + Note: + The SQLAlchemy `Column` default for `type_` is actually `None`, but it makes more sense to leave it undefined, + unless an argument is passed explicitly. If someone explicitly wants to pass `None` to set the `NullType` for + whatever reason, they will be able to do so. + (see [`type_`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.type_)) + """ current_schema_extra = schema_extra or {} + # For backwards compatibility: (!?) + if sa_column is not Undefined: + warnings.warn( + "Specifying `sa_column` overrides all other column arguments", + DeprecationWarning, + ) + if sa_column_args != (): + warnings.warn( + "Instead of `sa_column_args` use positional arguments", + DeprecationWarning, + ) + if sa_column_kwargs is not None: + warnings.warn( + "`sa_column_kwargs` takes precedence over other keyword-arguments", + DeprecationWarning, + ) field_info = FieldInfo( - default, + default=default, default_factory=default_factory, alias=alias, title=title, @@ -172,14 +294,27 @@ def Field( max_length=max_length, allow_mutation=allow_mutation, regex=regex, + name=name, + type_=type_, + args=args, + autoincrement=autoincrement, + doc=doc, + key=key, + index=index, + info=info, + nullable=nullable, + onupdate=onupdate, primary_key=primary_key, - foreign_key=foreign_key, + server_default=server_default, + server_onupdate=server_onupdate, + quote=quote, unique=unique, - nullable=nullable, - index=index, + system=system, + comment=comment, + foreign_key=foreign_key, sa_column=sa_column, sa_column_args=sa_column_args, - sa_column_kwargs=sa_column_kwargs, + sa_column_kwargs=sa_column_kwargs or {}, **current_schema_extra, ) field_info._validate() @@ -414,47 +549,53 @@ def get_sqlalchemy_type(field: ModelField) -> Any: raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): - return sa_column - sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", False) - index = getattr(field.field_info, "index", Undefined) - if index is Undefined: - index = False - nullable = not primary_key and _is_field_noneable(field) - # Override derived nullability if the nullable property is set explicitly - # on the field - if hasattr(field.field_info, "nullable"): - field_nullable = getattr(field.field_info, "nullable") - if field_nullable != Undefined: - nullable = field_nullable - args = [] - foreign_key = getattr(field.field_info, "foreign_key", None) - unique = getattr(field.field_info, "unique", False) - if foreign_key: - args.append(ForeignKey(foreign_key)) +def get_field_info_default(info: PydanticFieldInfo) -> Any: + """Returns the `default_factory` if set, otherwise the `default` value.""" + return info.default_factory if info.default_factory is not None else info.default + + +def get_column_from_pydantic_field(field: ModelField) -> Column: # type: ignore[type-arg] + """Returns an SQLAlchemy `Column` instance derived from a regular Pydantic `ModelField`.""" kwargs = { - "primary_key": primary_key, - "nullable": nullable, - "index": index, - "unique": unique, + "type_": get_sqlalchemy_type(field), + "nullable": _is_field_noneable(field), } - sa_default = Undefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not Undefined: - sa_default = field.field_info.default - if sa_default is not Undefined: - kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) - return Column(sa_type, *args, **kwargs) # type: ignore + default = get_field_info_default(field.field_info) + if default is not Undefined: + kwargs["default"] = default + return Column(**kwargs) + + +def get_column_from_field(field: ModelField) -> Column: # type: ignore[type-arg] + """Returns an SQLAlchemy `Column` instance derived from an SQLModel field.""" + if not isinstance( + field.field_info, FieldInfo + ): # must be regular `PydanticFieldInfo` + return get_column_from_pydantic_field(field) + # We are dealing with the customized `FieldInfo` object: + field_info: FieldInfo = field.field_info + # The `sa_column` argument trumps everything: (for backwards compatibility) + if isinstance(field_info.sa_column, Column): + return field_info.sa_column + args: List[SchemaItem] = [] + kwargs = field_info.get_defined_column_kwargs() + # Only if no column type was explicitly defined, do we derive it here: + kwargs.setdefault("type_", get_sqlalchemy_type(field)) + # Only if nullability was not defined, do we infer it here: + kwargs.setdefault( + "nullable", not kwargs.get("primary_key", False) and _is_field_noneable(field) + ) + # If a foreign key reference was explicitly named, construct the schema item here, + # and make it the first positional argument for the `Column`: + if field_info.foreign_key: + args.append(ForeignKey(field_info.foreign_key)) + # All other positional column arguments are appended: + args.extend(field_info.args) + # Append `sa_column_args`: (for backwards compatibility) + args.extend(field_info.sa_column_args) + # Finally, let the `sa_column_kwargs` take precedence: (for backwards compatibility) + kwargs.update(field_info.sa_column_kwargs) + return Column(*args, **kwargs) class_registry = weakref.WeakValueDictionary() # type: ignore @@ -647,9 +788,7 @@ def __tablename__(cls) -> str: def _is_field_noneable(field: ModelField) -> bool: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) - return False + if field.required: + return False + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and (field.shape != SHAPE_SINGLETON or not field.sub_fields) diff --git a/tests/test_sa_column.py b/tests/test_sa_column.py new file mode 100644 index 0000000000..f4463c0988 --- /dev/null +++ b/tests/test_sa_column.py @@ -0,0 +1,33 @@ +"""These test cases should become obsolete once `sa_column`-parameters are dropped from `Field`.""" + +import pytest +from pydantic.config import BaseConfig +from pydantic.fields import ModelField +from sqlalchemy.sql.schema import CheckConstraint, Column, ForeignKey +from sqlmodel.main import Field, FieldInfo, get_column_from_field + + +def test_sa_column_params_raise_warnings(): + with pytest.warns(DeprecationWarning): + Field(sa_column=Column()) + with pytest.warns(DeprecationWarning): + Field(sa_column_args=[ForeignKey("foo.id"), CheckConstraint(">1")]) + with pytest.warns(DeprecationWarning): + Field(sa_column_kwargs={"name": "foo"}) + + +def test_sa_column_overrides_other_params(): + col = Column() + field = ModelField( + name="foo", + type_=str, + class_validators=None, + model_config=BaseConfig, + field_info=FieldInfo( + index=True, # should be ignored + sa_column=col, + ), + ) + output = get_column_from_field(field) + assert output is col + assert output.index is None