Skip to content

Commit 165e9c5

Browse files
committed
Add AutoString support for Pydantic network types
1 parent 467d153 commit 165e9c5

File tree

5 files changed

+102
-7
lines changed

5 files changed

+102
-7
lines changed

docs/advanced/column-types.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,13 @@ In addition, the following types are stored as `VARCHAR`:
112112
* ipaddress.IPv6Address
113113
* ipaddress.IPv6Network
114114
* pathlib.Path
115+
* pydantic.networks.IPvAnyAddress
116+
* pydantic.networks.IPvAnyInterface
117+
* pydantic.networks.IPvAnyNetwork
115118
* pydantic.EmailStr
116119

120+
Note that while the column types for these are `VARCHAR`, values are not converted to and from strings.
121+
117122
### IP Addresses
118123

119124
IP Addresses from the <a href="https://docs.python.org/3/library/ipaddress.html" class="external-link" target="_blank">Python `ipaddress` module</a> are stored as text.
Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
import ipaddress
21
from datetime import UTC, datetime
32
from pathlib import Path
43
from uuid import UUID, uuid4
54

6-
from pydantic import EmailStr
7-
from sqlmodel import Field, SQLModel, create_engine
5+
from pydantic import EmailStr, IPvAnyAddress
6+
from sqlmodel import Field, Session, SQLModel, create_engine, select
87

98

109
class Avatar(SQLModel, table=True):
1110
id: UUID = Field(default_factory=uuid4, primary_key=True)
12-
source_ip_address: ipaddress.IPv4Address
11+
source_ip_address: IPvAnyAddress
1312
upload_location: Path
1413
uploaded_at: datetime = Field(default=datetime.now(tz=UTC))
1514
author_email: EmailStr
@@ -20,4 +19,55 @@ class Avatar(SQLModel, table=True):
2019

2120
engine = create_engine(sqlite_url, echo=True)
2221

23-
SQLModel.metadata.create_all(engine)
22+
23+
def create_db_and_tables():
24+
SQLModel.metadata.create_all(engine)
25+
26+
27+
def create_avatars():
28+
avatar_1 = Avatar(
29+
source_ip_address="127.0.0.1",
30+
upload_location="/uploads/1/123456789.jpg",
31+
author_email="[email protected]",
32+
)
33+
34+
avatar_2 = Avatar(
35+
source_ip_address="192.168.0.1",
36+
upload_location="/uploads/9/987654321.png",
37+
author_email="[email protected]",
38+
)
39+
40+
with Session(engine) as session:
41+
session.add(avatar_1)
42+
session.add(avatar_2)
43+
44+
session.commit()
45+
46+
47+
def read_avatars():
48+
with Session(engine) as session:
49+
statement = select(Avatar).where(Avatar.author_email == "[email protected]")
50+
result = session.exec(statement)
51+
avatar_1: Avatar = result.one()
52+
53+
print(
54+
"Avatar 1:",
55+
{
56+
"email": avatar_1.author_email,
57+
"email_type": type(avatar_1.author_email),
58+
"ip_address": avatar_1.source_ip_address,
59+
"ip_address_type": type(avatar_1.source_ip_address),
60+
"upload_location": avatar_1.upload_location,
61+
"upload_location_type": type(avatar_1.upload_location),
62+
},
63+
)
64+
65+
66+
def main():
67+
create_db_and_tables()
68+
create_avatars()
69+
read_avatars()
70+
71+
72+
if __name__ == "__main__":
73+
main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ typer-cli = "^0.0.13"
5656
mkdocs-markdownextradata-plugin = ">=0.1.7,<0.3.0"
5757
# For column type tests
5858
wonderwords = "^2.2.0"
59-
geoalchemy2 = "^0.14.3"
59+
pydantic = {extras = ["email"], version = ">=1.10.13,<3.0.0"}
6060

6161
[build-system]
6262
requires = ["poetry-core"]

sqlmodel/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
overload,
2525
)
2626

27-
from pydantic import BaseModel, EmailStr
27+
from pydantic import BaseModel
2828
from pydantic.fields import FieldInfo as PydanticFieldInfo
29+
from pydantic.networks import EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork
2930
from sqlalchemy import (
3031
Boolean,
3132
Column,
@@ -600,6 +601,12 @@ def get_sqlalchemy_type(field: Any) -> Any:
600601
return AutoString
601602
if issubclass(type_, ipaddress.IPv6Network):
602603
return AutoString
604+
if issubclass(type_, IPvAnyAddress):
605+
return AutoString
606+
if issubclass(type_, IPvAnyInterface):
607+
return AutoString
608+
if issubclass(type_, IPvAnyNetwork):
609+
return AutoString
603610
if issubclass(type_, Path):
604611
return AutoString
605612
if issubclass(type_, EmailStr):
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from unittest.mock import patch
2+
3+
from sqlmodel import create_engine
4+
5+
from ...conftest import get_testing_print_function
6+
7+
expected_calls = [
8+
[
9+
"Avatar 1:",
10+
{
11+
"email": "[email protected]",
12+
"email_type": str,
13+
"ip_address": "127.0.0.1",
14+
"ip_address_type": str,
15+
"upload_location": "/uploads/1/123456789.jpg",
16+
"upload_location_type": str,
17+
},
18+
],
19+
]
20+
21+
22+
def test_tutorial(clear_sqlmodel):
23+
from docs_src.advanced.column_types import tutorial003 as mod
24+
25+
mod.sqlite_url = "sqlite://"
26+
mod.engine = create_engine(mod.sqlite_url)
27+
calls = []
28+
29+
new_print = get_testing_print_function(calls)
30+
31+
with patch("builtins.print", new=new_print):
32+
mod.main()
33+
assert calls == expected_calls

0 commit comments

Comments
 (0)