diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..9ca90a87af 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,7 +52,14 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + deprecated, + get_args, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -473,6 +480,16 @@ def Relationship( return relationship_info +def get_annotated_relationshipinfo(t: Any) -> Optional[RelationshipInfo]: + """Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo.""" + if get_origin(t) is not Annotated: + return None + for a in get_args(t): + if isinstance(a, RelationshipInfo): + return a + return None + + @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] @@ -513,7 +530,12 @@ def __new__( else: dict_for_pydantic[k] = v for k, v in original_annotations.items(): - if k in relationships: + # check for `field: Annotated[Any, Relationship()]` + t = get_annotated_relationshipinfo(v) + if t: + relationships[k] = t + relationship_annotations[k] = get_args(v)[0] + elif k in relationships: relationship_annotations[k] = v else: pydantic_annotations[k] = v