1
1
import ipaddress
2
2
import uuid
3
+ import warnings
3
4
import weakref
4
5
from datetime import date , datetime , time , timedelta
5
6
from decimal import Decimal
38
39
from sqlalchemy .orm .attributes import set_attribute
39
40
from sqlalchemy .orm .decl_api import DeclarativeMeta
40
41
from sqlalchemy .orm .instrumentation import is_instrumented
41
- from sqlalchemy .sql .schema import MetaData
42
+ from sqlalchemy .sql .elements import TextClause
43
+ from sqlalchemy .sql .schema import FetchedValue , MetaData , SchemaItem
42
44
from sqlalchemy .sql .sqltypes import LargeBinary , Time
45
+ from sqlalchemy .sql .type_api import TypeEngine
43
46
44
47
from .sql .sqltypes import GUID , AutoString
45
48
@@ -57,35 +60,94 @@ def __dataclass_transform__(
57
60
58
61
59
62
class FieldInfo (PydanticFieldInfo ):
60
- def __init__ (self , default : Any = Undefined , ** kwargs : Any ) -> None :
61
- primary_key = kwargs .pop ("primary_key" , False )
62
- nullable = kwargs .pop ("nullable" , Undefined )
63
- foreign_key = kwargs .pop ("foreign_key" , Undefined )
64
- unique = kwargs .pop ("unique" , False )
65
- index = kwargs .pop ("index" , Undefined )
66
- sa_column = kwargs .pop ("sa_column" , Undefined )
67
- sa_column_args = kwargs .pop ("sa_column_args" , Undefined )
68
- sa_column_kwargs = kwargs .pop ("sa_column_kwargs" , Undefined )
69
- if sa_column is not Undefined :
70
- if sa_column_args is not Undefined :
71
- raise RuntimeError (
72
- "Passing sa_column_args is not supported when "
73
- "also passing a sa_column"
74
- )
75
- if sa_column_kwargs is not Undefined :
76
- raise RuntimeError (
77
- "Passing sa_column_kwargs is not supported when "
78
- "also passing a sa_column"
79
- )
80
- super ().__init__ (default = default , ** kwargs )
81
- self .primary_key = primary_key
82
- self .nullable = nullable
83
- self .foreign_key = foreign_key
84
- self .unique = unique
85
- self .index = index
86
- self .sa_column = sa_column
87
- self .sa_column_args = sa_column_args
88
- self .sa_column_kwargs = sa_column_kwargs
63
+
64
+ # In addition to the `PydanticFieldInfo` slots, set slots corresponding to parameters for the SQLAlchemy
65
+ # [Column](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column),
66
+ # along with any custom additions:
67
+ __slots__ = (
68
+ "name" ,
69
+ "type_" ,
70
+ "args" ,
71
+ "autoincrement" ,
72
+ # `default` omitted because that slot is defined on the base class
73
+ "doc" ,
74
+ "key" ,
75
+ "index" ,
76
+ "info" ,
77
+ "nullable" ,
78
+ "onupdate" ,
79
+ "primary_key" ,
80
+ "server_default" ,
81
+ "server_onupdate" ,
82
+ "quote" ,
83
+ "unique" ,
84
+ "system" ,
85
+ "comment" ,
86
+ "foreign_key" , # custom parameter for easier foreign key setting
87
+ # For backwards compatibility: (!?)
88
+ "sa_column" ,
89
+ "sa_column_args" ,
90
+ "sa_column_kwargs" ,
91
+ )
92
+
93
+ # Defined here for static type checkers:
94
+ name : Union [str , UndefinedType ]
95
+ type_ : Union [TypeEngine , UndefinedType ] # type: ignore[type-arg]
96
+ args : Sequence [SchemaItem ]
97
+ autoincrement : Union [bool , str ]
98
+ doc : Optional [str ]
99
+ key : Union [str , UndefinedType ]
100
+ index : Optional [bool ]
101
+ info : Union [Dict [str , Any ], UndefinedType ]
102
+ nullable : Union [bool , UndefinedType ]
103
+ onupdate : Any
104
+ primary_key : bool
105
+ server_default : Union [FetchedValue , str , TextClause , None ]
106
+ server_onupdate : Optional [FetchedValue ]
107
+ quote : Union [bool , None , UndefinedType ]
108
+ unique : Optional [bool ]
109
+ system : bool
110
+ comment : Optional [str ]
111
+
112
+ foreign_key : Optional [str ]
113
+
114
+ sa_column : Union [Column , UndefinedType ] # type: ignore[type-arg]
115
+ sa_column_args : Sequence [Any ]
116
+ sa_column_kwargs : Mapping [str , Any ]
117
+
118
+ def __init__ (self , ** kwargs : Any ) -> None :
119
+ # Split off all keyword-arguments corresponding to our new additional attributes:
120
+ new_kwargs = {param : kwargs .pop (param , Undefined ) for param in self .__slots__ }
121
+ # Pass the rest of the keyword-arguments to the Pydantic `FieldInfo.__init__`:
122
+ super ().__init__ (** kwargs )
123
+ # Set the other keyword-arguments as instance attributes:
124
+ for param , value in new_kwargs .items ():
125
+ setattr (self , param , value )
126
+
127
+ def get_defined_column_kwargs (self ) -> Dict [str , Any ]:
128
+ """
129
+ Returns a dictionary of keyword arguments for the SQLAlchemy `Column.__init__` method
130
+ derived from the corresponding attributes of the `FieldInfo` instance,
131
+ omitting all those that have been left undefined.
132
+ """
133
+ special = {
134
+ "args" ,
135
+ "foreign_key" ,
136
+ "sa_column" ,
137
+ "sa_column_args" ,
138
+ "sa_column_kwargs" ,
139
+ }
140
+ kwargs = {}
141
+ for key in self .__slots__ :
142
+ if key in special :
143
+ continue
144
+ value = getattr (self , key , Undefined )
145
+ if value is not Undefined :
146
+ kwargs [key ] = value
147
+ default = get_field_info_default (self )
148
+ if default is not Undefined :
149
+ kwargs ["default" ] = default
150
+ return kwargs
89
151
90
152
91
153
class RelationshipInfo (Representation ):
@@ -117,8 +179,9 @@ def __init__(
117
179
118
180
119
181
def Field (
120
- default : Any = Undefined ,
121
- * ,
182
+ * args : SchemaItem , # positional arguments for SQLAlchemy `Column.__init__`
183
+ default : Any = Undefined , # meaningful for both Pydantic and SQLAlchemy
184
+ # The following are specific to Pydantic:
122
185
default_factory : Optional [NoArgAnyCallable ] = None ,
123
186
alias : Optional [str ] = None ,
124
187
title : Optional [str ] = None ,
@@ -141,19 +204,78 @@ def Field(
141
204
max_length : Optional [int ] = None ,
142
205
allow_mutation : bool = True ,
143
206
regex : Optional [str ] = None ,
207
+ # The following are specific to SQLAlchemy:
208
+ name : Optional [str ] = None ,
209
+ type_ : Union [TypeEngine , UndefinedType ] = Undefined , # type: ignore[type-arg]
210
+ autoincrement : Union [bool , str ] = "auto" ,
211
+ doc : Optional [str ] = None ,
212
+ key : Union [str , UndefinedType ] = Undefined , # `Column` default is `name`
213
+ index : Optional [bool ] = None ,
214
+ info : Union [Dict [str , Any ], UndefinedType ] = Undefined , # `Column` default is `{}`
215
+ nullable : Union [
216
+ bool , UndefinedType
217
+ ] = Undefined , # `Column` default depends on `primary_key`
218
+ onupdate : Any = None ,
144
219
primary_key : bool = False ,
145
- foreign_key : Optional [Any ] = None ,
146
- unique : bool = False ,
147
- nullable : Union [bool , UndefinedType ] = Undefined ,
148
- index : Union [bool , UndefinedType ] = Undefined ,
149
- sa_column : Union [Column , UndefinedType ] = Undefined , # type: ignore
150
- sa_column_args : Union [Sequence [Any ], UndefinedType ] = Undefined ,
151
- sa_column_kwargs : Union [Mapping [str , Any ], UndefinedType ] = Undefined ,
220
+ server_default : Union [FetchedValue , str , TextClause , None ] = None ,
221
+ server_onupdate : Optional [FetchedValue ] = None ,
222
+ quote : Union [
223
+ bool , None , UndefinedType
224
+ ] = Undefined , # `Column` default not (fully) defined
225
+ unique : Optional [bool ] = None ,
226
+ system : bool = False ,
227
+ comment : Optional [str ] = None ,
228
+ foreign_key : Optional [str ] = None ,
229
+ # For backwards compatibility: (!?)
230
+ sa_column : Union [Column , UndefinedType ] = Undefined , # type: ignore[type-arg]
231
+ sa_column_args : Sequence [Any ] = (),
232
+ sa_column_kwargs : Optional [Mapping [str , Any ]] = None ,
233
+ # Extra:
152
234
schema_extra : Optional [Dict [str , Any ]] = None ,
153
- ) -> Any :
235
+ ) -> FieldInfo :
236
+ """
237
+ Constructor for explicitly defining the attributes of a model field.
238
+
239
+ The resulting field information is used both for Pydantic model validation **and** for SQLAlchemy column definition.
240
+
241
+ The following parameters are passed to initialize the Pydantic `FieldInfo`
242
+ (see [`Field` docs](https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)):
243
+ `default`, `default_factory`, `alias`, `title`, `description`, `exclude`, `include`, `const`, `gt`, `ge`,
244
+ `lt`, `le`, `multiple_of`, `min_items`, `max_items`, `min_length`, `max_length`, `allow_mutation`, `regex`.
245
+
246
+ These parameters are passed to initialize the SQLAlchemy
247
+ [`Column`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column):
248
+ `*args`, `name`, `type_`, `autoincrement`, `doc`, `key`, `index`, `info`, `nullable`, `onupdate`, `primary_key`,
249
+ `server_default`, `server_onupdate`, `quote`, `unique`, `system`, `comment`.
250
+
251
+ If provided, the `default_factory` argument is passed as `default` to the `Column` constructor;
252
+ otherwise, if the `default` argument is provided, it is passed to the `Column` constructor.
253
+
254
+ Note:
255
+ The SQLAlchemy `Column` default for `type_` is actually `None`, but it makes more sense to leave it undefined,
256
+ unless an argument is passed explicitly. If someone explicitly wants to pass `None` to set the `NullType` for
257
+ whatever reason, they will be able to do so.
258
+ (see [`type_`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.type_))
259
+ """
154
260
current_schema_extra = schema_extra or {}
261
+ # For backwards compatibility: (!?)
262
+ if sa_column is not Undefined :
263
+ warnings .warn (
264
+ "Specifying `sa_column` overrides all other column arguments" ,
265
+ DeprecationWarning ,
266
+ )
267
+ if sa_column_args != ():
268
+ warnings .warn (
269
+ "Instead of `sa_column_args` use positional arguments" ,
270
+ DeprecationWarning ,
271
+ )
272
+ if sa_column_kwargs is not None :
273
+ warnings .warn (
274
+ "`sa_column_kwargs` takes precedence over other keyword-arguments" ,
275
+ DeprecationWarning ,
276
+ )
155
277
field_info = FieldInfo (
156
- default ,
278
+ default = default ,
157
279
default_factory = default_factory ,
158
280
alias = alias ,
159
281
title = title ,
@@ -172,14 +294,27 @@ def Field(
172
294
max_length = max_length ,
173
295
allow_mutation = allow_mutation ,
174
296
regex = regex ,
297
+ name = name ,
298
+ type_ = type_ ,
299
+ args = args ,
300
+ autoincrement = autoincrement ,
301
+ doc = doc ,
302
+ key = key ,
303
+ index = index ,
304
+ info = info ,
305
+ nullable = nullable ,
306
+ onupdate = onupdate ,
175
307
primary_key = primary_key ,
176
- foreign_key = foreign_key ,
308
+ server_default = server_default ,
309
+ server_onupdate = server_onupdate ,
310
+ quote = quote ,
177
311
unique = unique ,
178
- nullable = nullable ,
179
- index = index ,
312
+ system = system ,
313
+ comment = comment ,
314
+ foreign_key = foreign_key ,
180
315
sa_column = sa_column ,
181
316
sa_column_args = sa_column_args ,
182
- sa_column_kwargs = sa_column_kwargs ,
317
+ sa_column_kwargs = sa_column_kwargs or {} ,
183
318
** current_schema_extra ,
184
319
)
185
320
field_info ._validate ()
@@ -414,47 +549,48 @@ def get_sqlachemy_type(field: ModelField) -> Any:
414
549
raise ValueError (f"The field { field .name } has no matching SQLAlchemy type" )
415
550
416
551
417
- def get_column_from_field (field : ModelField ) -> Column : # type: ignore
418
- sa_column = getattr (field .field_info , "sa_column" , Undefined )
419
- if isinstance (sa_column , Column ):
420
- return sa_column
421
- sa_type = get_sqlachemy_type (field )
422
- primary_key = getattr (field .field_info , "primary_key" , False )
423
- index = getattr (field .field_info , "index" , Undefined )
424
- if index is Undefined :
425
- index = False
426
- nullable = not primary_key and _is_field_noneable (field )
427
- # Override derived nullability if the nullable property is set explicitly
428
- # on the field
429
- if hasattr (field .field_info , "nullable" ):
430
- field_nullable = getattr (field .field_info , "nullable" )
431
- if field_nullable != Undefined :
432
- nullable = field_nullable
433
- args = []
434
- foreign_key = getattr (field .field_info , "foreign_key" , None )
435
- unique = getattr (field .field_info , "unique" , False )
436
- if foreign_key :
437
- args .append (ForeignKey (foreign_key ))
438
- kwargs = {
439
- "primary_key" : primary_key ,
440
- "nullable" : nullable ,
441
- "index" : index ,
442
- "unique" : unique ,
443
- }
444
- sa_default = Undefined
445
- if field .field_info .default_factory :
446
- sa_default = field .field_info .default_factory
447
- elif field .field_info .default is not Undefined :
448
- sa_default = field .field_info .default
449
- if sa_default is not Undefined :
450
- kwargs ["default" ] = sa_default
451
- sa_column_args = getattr (field .field_info , "sa_column_args" , Undefined )
452
- if sa_column_args is not Undefined :
453
- args .extend (list (cast (Sequence [Any ], sa_column_args )))
454
- sa_column_kwargs = getattr (field .field_info , "sa_column_kwargs" , Undefined )
455
- if sa_column_kwargs is not Undefined :
456
- kwargs .update (cast (Dict [Any , Any ], sa_column_kwargs ))
457
- return Column (sa_type , * args , ** kwargs ) # type: ignore
552
+ def get_field_info_default (info : PydanticFieldInfo ) -> Any :
553
+ """Returns the `default_factory` if set, otherwise the `default` value."""
554
+ return info .default_factory if info .default_factory is not None else info .default
555
+
556
+
557
+ def get_column_from_pydantic_field (field : ModelField ) -> Column : # type: ignore[type-arg]
558
+ """Returns an SQLAlchemy `Column` instance derived from a regular Pydantic `ModelField`."""
559
+ kwargs = {"type_" : get_sqlachemy_type (field ), "nullable" : _is_field_noneable (field )}
560
+ default = get_field_info_default (field .field_info )
561
+ if default is not Undefined :
562
+ kwargs ["default" ] = default
563
+ return Column (** kwargs )
564
+
565
+
566
+ def get_column_from_field (field : ModelField ) -> Column : # type: ignore[type-arg]
567
+ """Returns an SQLAlchemy `Column` instance derived from an SQLModel field."""
568
+ if not isinstance (field .field_info , FieldInfo ): # must be regular `PydanticFieldInfo`
569
+ return get_column_from_pydantic_field (field )
570
+ # We are dealing with the customized `FieldInfo` object:
571
+ field_info : FieldInfo = field .field_info
572
+ # The `sa_column` argument trumps everything: (for backwards compatibility)
573
+ if isinstance (field_info .sa_column , Column ):
574
+ return field_info .sa_column
575
+ args : List [SchemaItem ] = []
576
+ kwargs = field_info .get_defined_column_kwargs ()
577
+ # Only if no column type was explicitly defined, do we derive it here:
578
+ kwargs .setdefault ("type_" , get_sqlachemy_type (field ))
579
+ # Only if nullability was not defined, do we infer it here:
580
+ kwargs .setdefault (
581
+ "nullable" , not kwargs .get ("primary_key" , False ) and _is_field_noneable (field )
582
+ )
583
+ # If a foreign key reference was explicitly named, construct the schema item here,
584
+ # and make it the first positional argument for the `Column`:
585
+ if field_info .foreign_key :
586
+ args .append (ForeignKey (field_info .foreign_key ))
587
+ # All other positional column arguments are appended:
588
+ args .extend (field_info .args )
589
+ # Append `sa_column_args`: (for backwards compatibility)
590
+ args .extend (field_info .sa_column_args )
591
+ # Finally, let the `sa_column_kwargs` take precedence: (for backwards compatibility)
592
+ kwargs .update (field_info .sa_column_kwargs )
593
+ return Column (* args , ** kwargs )
458
594
459
595
460
596
class_registry = weakref .WeakValueDictionary () # type: ignore
@@ -647,9 +783,7 @@ def __tablename__(cls) -> str:
647
783
648
784
649
785
def _is_field_noneable (field : ModelField ) -> bool :
650
- if not field .required :
651
- # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
652
- return field .allow_none and (
653
- field .shape != SHAPE_SINGLETON or not field .sub_fields
654
- )
655
- return False
786
+ if field .required :
787
+ return False
788
+ # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
789
+ return field .allow_none and (field .shape != SHAPE_SINGLETON or not field .sub_fields )
0 commit comments