Skip to content

Commit c93b42e

Browse files
committed
✨ Fully merge Pydantic Field with SQLAlchemy Column constructor;
allow passing all `Column` arguments directly to `Field`; make `default` a keyword-only argument for `Field`
1 parent 75ce455 commit c93b42e

File tree

1 file changed

+226
-92
lines changed

1 file changed

+226
-92
lines changed

sqlmodel/main.py

Lines changed: 226 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ipaddress
22
import uuid
3+
import warnings
34
import weakref
45
from datetime import date, datetime, time, timedelta
56
from decimal import Decimal
@@ -38,8 +39,10 @@
3839
from sqlalchemy.orm.attributes import set_attribute
3940
from sqlalchemy.orm.decl_api import DeclarativeMeta
4041
from sqlalchemy.orm.instrumentation import is_instrumented
41-
from sqlalchemy.sql.schema import MetaData
42+
from sqlalchemy.sql.elements import TextClause
43+
from sqlalchemy.sql.schema import FetchedValue, MetaData, SchemaItem
4244
from sqlalchemy.sql.sqltypes import LargeBinary, Time
45+
from sqlalchemy.sql.type_api import TypeEngine
4346

4447
from .sql.sqltypes import GUID, AutoString
4548

@@ -57,35 +60,94 @@ def __dataclass_transform__(
5760

5861

5962
class FieldInfo(PydanticFieldInfo):
60-
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
61-
primary_key = kwargs.pop("primary_key", False)
62-
nullable = kwargs.pop("nullable", Undefined)
63-
foreign_key = kwargs.pop("foreign_key", Undefined)
64-
unique = kwargs.pop("unique", False)
65-
index = kwargs.pop("index", Undefined)
66-
sa_column = kwargs.pop("sa_column", Undefined)
67-
sa_column_args = kwargs.pop("sa_column_args", Undefined)
68-
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
69-
if sa_column is not Undefined:
70-
if sa_column_args is not Undefined:
71-
raise RuntimeError(
72-
"Passing sa_column_args is not supported when "
73-
"also passing a sa_column"
74-
)
75-
if sa_column_kwargs is not Undefined:
76-
raise RuntimeError(
77-
"Passing sa_column_kwargs is not supported when "
78-
"also passing a sa_column"
79-
)
80-
super().__init__(default=default, **kwargs)
81-
self.primary_key = primary_key
82-
self.nullable = nullable
83-
self.foreign_key = foreign_key
84-
self.unique = unique
85-
self.index = index
86-
self.sa_column = sa_column
87-
self.sa_column_args = sa_column_args
88-
self.sa_column_kwargs = sa_column_kwargs
63+
64+
# In addition to the `PydanticFieldInfo` slots, set slots corresponding to parameters for the SQLAlchemy
65+
# [Column](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column),
66+
# along with any custom additions:
67+
__slots__ = (
68+
"name",
69+
"type_",
70+
"args",
71+
"autoincrement",
72+
# `default` omitted because that slot is defined on the base class
73+
"doc",
74+
"key",
75+
"index",
76+
"info",
77+
"nullable",
78+
"onupdate",
79+
"primary_key",
80+
"server_default",
81+
"server_onupdate",
82+
"quote",
83+
"unique",
84+
"system",
85+
"comment",
86+
"foreign_key", # custom parameter for easier foreign key setting
87+
# For backwards compatibility: (!?)
88+
"sa_column",
89+
"sa_column_args",
90+
"sa_column_kwargs",
91+
)
92+
93+
# Defined here for static type checkers:
94+
name: Union[str, UndefinedType]
95+
type_: Union[TypeEngine, UndefinedType] # type: ignore[type-arg]
96+
args: Sequence[SchemaItem]
97+
autoincrement: Union[bool, str]
98+
doc: Optional[str]
99+
key: Union[str, UndefinedType]
100+
index: Optional[bool]
101+
info: Union[Dict[str, Any], UndefinedType]
102+
nullable: Union[bool, UndefinedType]
103+
onupdate: Any
104+
primary_key: bool
105+
server_default: Union[FetchedValue, str, TextClause, None]
106+
server_onupdate: Optional[FetchedValue]
107+
quote: Union[bool, None, UndefinedType]
108+
unique: Optional[bool]
109+
system: bool
110+
comment: Optional[str]
111+
112+
foreign_key: Optional[str]
113+
114+
sa_column: Union[Column, UndefinedType] # type: ignore[type-arg]
115+
sa_column_args: Sequence[Any]
116+
sa_column_kwargs: Mapping[str, Any]
117+
118+
def __init__(self, **kwargs: Any) -> None:
119+
# Split off all keyword-arguments corresponding to our new additional attributes:
120+
new_kwargs = {param: kwargs.pop(param, Undefined) for param in self.__slots__}
121+
# Pass the rest of the keyword-arguments to the Pydantic `FieldInfo.__init__`:
122+
super().__init__(**kwargs)
123+
# Set the other keyword-arguments as instance attributes:
124+
for param, value in new_kwargs.items():
125+
setattr(self, param, value)
126+
127+
def get_defined_column_kwargs(self) -> Dict[str, Any]:
128+
"""
129+
Returns a dictionary of keyword arguments for the SQLAlchemy `Column.__init__` method
130+
derived from the corresponding attributes of the `FieldInfo` instance,
131+
omitting all those that have been left undefined.
132+
"""
133+
special = {
134+
"args",
135+
"foreign_key",
136+
"sa_column",
137+
"sa_column_args",
138+
"sa_column_kwargs",
139+
}
140+
kwargs = {}
141+
for key in self.__slots__:
142+
if key in special:
143+
continue
144+
value = getattr(self, key, Undefined)
145+
if value is not Undefined:
146+
kwargs[key] = value
147+
default = get_field_info_default(self)
148+
if default is not Undefined:
149+
kwargs["default"] = default
150+
return kwargs
89151

90152

91153
class RelationshipInfo(Representation):
@@ -117,8 +179,9 @@ def __init__(
117179

118180

119181
def Field(
120-
default: Any = Undefined,
121-
*,
182+
*args: SchemaItem, # positional arguments for SQLAlchemy `Column.__init__`
183+
default: Any = Undefined, # meaningful for both Pydantic and SQLAlchemy
184+
# The following are specific to Pydantic:
122185
default_factory: Optional[NoArgAnyCallable] = None,
123186
alias: Optional[str] = None,
124187
title: Optional[str] = None,
@@ -141,19 +204,78 @@ def Field(
141204
max_length: Optional[int] = None,
142205
allow_mutation: bool = True,
143206
regex: Optional[str] = None,
207+
# The following are specific to SQLAlchemy:
208+
name: Optional[str] = None,
209+
type_: Union[TypeEngine, UndefinedType] = Undefined, # type: ignore[type-arg]
210+
autoincrement: Union[bool, str] = "auto",
211+
doc: Optional[str] = None,
212+
key: Union[str, UndefinedType] = Undefined, # `Column` default is `name`
213+
index: Optional[bool] = None,
214+
info: Union[Dict[str, Any], UndefinedType] = Undefined, # `Column` default is `{}`
215+
nullable: Union[
216+
bool, UndefinedType
217+
] = Undefined, # `Column` default depends on `primary_key`
218+
onupdate: Any = None,
144219
primary_key: bool = False,
145-
foreign_key: Optional[Any] = None,
146-
unique: bool = False,
147-
nullable: Union[bool, UndefinedType] = Undefined,
148-
index: Union[bool, UndefinedType] = Undefined,
149-
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
150-
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
151-
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
220+
server_default: Union[FetchedValue, str, TextClause, None] = None,
221+
server_onupdate: Optional[FetchedValue] = None,
222+
quote: Union[
223+
bool, None, UndefinedType
224+
] = Undefined, # `Column` default not (fully) defined
225+
unique: Optional[bool] = None,
226+
system: bool = False,
227+
comment: Optional[str] = None,
228+
foreign_key: Optional[str] = None,
229+
# For backwards compatibility: (!?)
230+
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore[type-arg]
231+
sa_column_args: Sequence[Any] = (),
232+
sa_column_kwargs: Optional[Mapping[str, Any]] = None,
233+
# Extra:
152234
schema_extra: Optional[Dict[str, Any]] = None,
153-
) -> Any:
235+
) -> FieldInfo:
236+
"""
237+
Constructor for explicitly defining the attributes of a model field.
238+
239+
The resulting field information is used both for Pydantic model validation **and** for SQLAlchemy column definition.
240+
241+
The following parameters are passed to initialize the Pydantic `FieldInfo`
242+
(see [`Field` docs](https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)):
243+
`default`, `default_factory`, `alias`, `title`, `description`, `exclude`, `include`, `const`, `gt`, `ge`,
244+
`lt`, `le`, `multiple_of`, `min_items`, `max_items`, `min_length`, `max_length`, `allow_mutation`, `regex`.
245+
246+
These parameters are passed to initialize the SQLAlchemy
247+
[`Column`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column):
248+
`*args`, `name`, `type_`, `autoincrement`, `doc`, `key`, `index`, `info`, `nullable`, `onupdate`, `primary_key`,
249+
`server_default`, `server_onupdate`, `quote`, `unique`, `system`, `comment`.
250+
251+
If provided, the `default_factory` argument is passed as `default` to the `Column` constructor;
252+
otherwise, if the `default` argument is provided, it is passed to the `Column` constructor.
253+
254+
Note:
255+
The SQLAlchemy `Column` default for `type_` is actually `None`, but it makes more sense to leave it undefined,
256+
unless an argument is passed explicitly. If someone explicitly wants to pass `None` to set the `NullType` for
257+
whatever reason, they will be able to do so.
258+
(see [`type_`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.type_))
259+
"""
154260
current_schema_extra = schema_extra or {}
261+
# For backwards compatibility: (!?)
262+
if sa_column is not Undefined:
263+
warnings.warn(
264+
"Specifying `sa_column` overrides all other column arguments",
265+
DeprecationWarning,
266+
)
267+
if sa_column_args != ():
268+
warnings.warn(
269+
"Instead of `sa_column_args` use positional arguments",
270+
DeprecationWarning,
271+
)
272+
if sa_column_kwargs is not None:
273+
warnings.warn(
274+
"`sa_column_kwargs` takes precedence over other keyword-arguments",
275+
DeprecationWarning,
276+
)
155277
field_info = FieldInfo(
156-
default,
278+
default=default,
157279
default_factory=default_factory,
158280
alias=alias,
159281
title=title,
@@ -172,14 +294,27 @@ def Field(
172294
max_length=max_length,
173295
allow_mutation=allow_mutation,
174296
regex=regex,
297+
name=name,
298+
type_=type_,
299+
args=args,
300+
autoincrement=autoincrement,
301+
doc=doc,
302+
key=key,
303+
index=index,
304+
info=info,
305+
nullable=nullable,
306+
onupdate=onupdate,
175307
primary_key=primary_key,
176-
foreign_key=foreign_key,
308+
server_default=server_default,
309+
server_onupdate=server_onupdate,
310+
quote=quote,
177311
unique=unique,
178-
nullable=nullable,
179-
index=index,
312+
system=system,
313+
comment=comment,
314+
foreign_key=foreign_key,
180315
sa_column=sa_column,
181316
sa_column_args=sa_column_args,
182-
sa_column_kwargs=sa_column_kwargs,
317+
sa_column_kwargs=sa_column_kwargs or {},
183318
**current_schema_extra,
184319
)
185320
field_info._validate()
@@ -414,47 +549,48 @@ def get_sqlachemy_type(field: ModelField) -> Any:
414549
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
415550

416551

417-
def get_column_from_field(field: ModelField) -> Column: # type: ignore
418-
sa_column = getattr(field.field_info, "sa_column", Undefined)
419-
if isinstance(sa_column, Column):
420-
return sa_column
421-
sa_type = get_sqlachemy_type(field)
422-
primary_key = getattr(field.field_info, "primary_key", False)
423-
index = getattr(field.field_info, "index", Undefined)
424-
if index is Undefined:
425-
index = False
426-
nullable = not primary_key and _is_field_noneable(field)
427-
# Override derived nullability if the nullable property is set explicitly
428-
# on the field
429-
if hasattr(field.field_info, "nullable"):
430-
field_nullable = getattr(field.field_info, "nullable")
431-
if field_nullable != Undefined:
432-
nullable = field_nullable
433-
args = []
434-
foreign_key = getattr(field.field_info, "foreign_key", None)
435-
unique = getattr(field.field_info, "unique", False)
436-
if foreign_key:
437-
args.append(ForeignKey(foreign_key))
438-
kwargs = {
439-
"primary_key": primary_key,
440-
"nullable": nullable,
441-
"index": index,
442-
"unique": unique,
443-
}
444-
sa_default = Undefined
445-
if field.field_info.default_factory:
446-
sa_default = field.field_info.default_factory
447-
elif field.field_info.default is not Undefined:
448-
sa_default = field.field_info.default
449-
if sa_default is not Undefined:
450-
kwargs["default"] = sa_default
451-
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
452-
if sa_column_args is not Undefined:
453-
args.extend(list(cast(Sequence[Any], sa_column_args)))
454-
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
455-
if sa_column_kwargs is not Undefined:
456-
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
457-
return Column(sa_type, *args, **kwargs) # type: ignore
552+
def get_field_info_default(info: PydanticFieldInfo) -> Any:
553+
"""Returns the `default_factory` if set, otherwise the `default` value."""
554+
return info.default_factory if info.default_factory is not None else info.default
555+
556+
557+
def get_column_from_pydantic_field(field: ModelField) -> Column: # type: ignore[type-arg]
558+
"""Returns an SQLAlchemy `Column` instance derived from a regular Pydantic `ModelField`."""
559+
kwargs = {"type_": get_sqlachemy_type(field), "nullable": _is_field_noneable(field)}
560+
default = get_field_info_default(field.field_info)
561+
if default is not Undefined:
562+
kwargs["default"] = default
563+
return Column(**kwargs)
564+
565+
566+
def get_column_from_field(field: ModelField) -> Column: # type: ignore[type-arg]
567+
"""Returns an SQLAlchemy `Column` instance derived from an SQLModel field."""
568+
if not isinstance(field.field_info, FieldInfo): # must be regular `PydanticFieldInfo`
569+
return get_column_from_pydantic_field(field)
570+
# We are dealing with the customized `FieldInfo` object:
571+
field_info: FieldInfo = field.field_info
572+
# The `sa_column` argument trumps everything: (for backwards compatibility)
573+
if isinstance(field_info.sa_column, Column):
574+
return field_info.sa_column
575+
args: List[SchemaItem] = []
576+
kwargs = field_info.get_defined_column_kwargs()
577+
# Only if no column type was explicitly defined, do we derive it here:
578+
kwargs.setdefault("type_", get_sqlachemy_type(field))
579+
# Only if nullability was not defined, do we infer it here:
580+
kwargs.setdefault(
581+
"nullable", not kwargs.get("primary_key", False) and _is_field_noneable(field)
582+
)
583+
# If a foreign key reference was explicitly named, construct the schema item here,
584+
# and make it the first positional argument for the `Column`:
585+
if field_info.foreign_key:
586+
args.append(ForeignKey(field_info.foreign_key))
587+
# All other positional column arguments are appended:
588+
args.extend(field_info.args)
589+
# Append `sa_column_args`: (for backwards compatibility)
590+
args.extend(field_info.sa_column_args)
591+
# Finally, let the `sa_column_kwargs` take precedence: (for backwards compatibility)
592+
kwargs.update(field_info.sa_column_kwargs)
593+
return Column(*args, **kwargs)
458594

459595

460596
class_registry = weakref.WeakValueDictionary() # type: ignore
@@ -647,9 +783,7 @@ def __tablename__(cls) -> str:
647783

648784

649785
def _is_field_noneable(field: ModelField) -> bool:
650-
if not field.required:
651-
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
652-
return field.allow_none and (
653-
field.shape != SHAPE_SINGLETON or not field.sub_fields
654-
)
655-
return False
786+
if field.required:
787+
return False
788+
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
789+
return field.allow_none and (field.shape != SHAPE_SINGLETON or not field.sub_fields)

0 commit comments

Comments
 (0)