Skip to content

Commit dc5d007

Browse files
guan404mingLee-Wjason810496
authored
feat: overwrite get_uri for JDBC (apache#48915)
* feat: overwrite `get_uri` for `JDBC` * fix: apply suggestions from code review Co-authored-by: Wei Lee <[email protected]> * fix: make string as format string --------- Co-authored-by: Wei Lee <[email protected]> Co-authored-by: LIU ZHE YOU <[email protected]>
1 parent c447ad3 commit dc5d007

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from contextlib import contextmanager
2323
from threading import RLock
2424
from typing import TYPE_CHECKING, Any
25+
from urllib.parse import quote_plus, urlencode
2526

2627
import jaydebeapi
2728
import jpype
@@ -220,3 +221,38 @@ def get_autocommit(self, conn: jaydebeapi.Connection) -> bool:
220221
with suppress_and_warn(jaydebeapi.Error, jpype.JException):
221222
return conn.jconn.getAutoCommit()
222223
return False
224+
225+
def get_uri(self) -> str:
226+
"""Get the connection URI for the JDBC connection."""
227+
conn = self.connection
228+
extra = conn.extra_dejson
229+
230+
scheme = extra.get("sqlalchemy_scheme")
231+
if not scheme:
232+
return conn.host
233+
234+
driver = extra.get("sqlalchemy_driver")
235+
uri_prefix = f"{scheme}+{driver}" if driver else scheme
236+
237+
auth_part = ""
238+
if conn.login:
239+
auth_part = quote_plus(conn.login)
240+
if conn.password:
241+
auth_part = f"{auth_part}:{quote_plus(conn.password)}"
242+
auth_part = f"{auth_part}@"
243+
244+
host_part = conn.host or "localhost"
245+
if conn.port:
246+
host_part = f"{host_part}:{conn.port}"
247+
248+
schema_part = f"/{quote_plus(conn.schema)}" if conn.schema else ""
249+
250+
uri = f"{uri_prefix}://{auth_part}{host_part}{schema_part}"
251+
252+
sqlalchemy_query = extra.get("sqlalchemy_query", {})
253+
if isinstance(sqlalchemy_query, dict):
254+
query_string = urlencode({k: str(v) for k, v in sqlalchemy_query.items() if v is not None})
255+
if query_string:
256+
uri = f"{uri}?{query_string}"
257+
258+
return uri

providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,70 @@ def call_get_conn():
309309
future.result() # This will raise OSError if get_conn isn't threadsafe
310310

311311
assert mock_connect.call_count == 10
312+
313+
@pytest.mark.parametrize(
314+
"params,expected_uri",
315+
[
316+
# JDBC URL fallback cases
317+
pytest.param(
318+
{"host": "jdbc:mysql://localhost:3306/test"},
319+
"jdbc:mysql://localhost:3306/test",
320+
id="jdbc-mysql",
321+
),
322+
pytest.param(
323+
{"host": "jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word"},
324+
"jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word",
325+
id="jdbc-postgresql",
326+
),
327+
pytest.param(
328+
{"host": "jdbc:oracle:thin:@localhost:1521:xe"},
329+
"jdbc:oracle:thin:@localhost:1521:xe",
330+
id="jdbc-oracle",
331+
),
332+
pytest.param(
333+
{"host": "jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true"},
334+
"jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true",
335+
id="jdbc-sqlserver",
336+
),
337+
# SQLAlchemy URI cases
338+
pytest.param(
339+
{
340+
"conn_params": {
341+
"extra": json.dumps(
342+
{"sqlalchemy_scheme": "mssql", "sqlalchemy_query": {"servicename": "test"}}
343+
)
344+
}
345+
},
346+
"mssql://login:password@host:1234/schema?servicename=test",
347+
id="sqlalchemy-scheme-with-query",
348+
),
349+
pytest.param(
350+
{
351+
"conn_params": {
352+
"extra": json.dumps(
353+
{"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg2"}
354+
)
355+
}
356+
},
357+
"postgresql+psycopg2://login:password@host:1234/schema",
358+
id="sqlalchemy-scheme-with-driver",
359+
),
360+
pytest.param(
361+
{
362+
"login": "user@domain",
363+
"password": "pass/word",
364+
"schema": "my/db",
365+
"conn_params": {"extra": json.dumps({"sqlalchemy_scheme": "mysql"})},
366+
},
367+
"mysql://user%40domain:pass%2Fword@host:1234/my%2Fdb",
368+
id="sqlalchemy-with-encoding",
369+
),
370+
],
371+
)
372+
def test_get_uri(self, params, expected_uri):
373+
"""Test get_uri with different configurations including JDBC URLs and SQLAlchemy URIs."""
374+
valid_keys = {"host", "login", "password", "schema", "conn_params"}
375+
hook_params = {key: params[key] for key in valid_keys & params.keys()}
376+
377+
jdbc_hook = get_hook(**hook_params)
378+
assert jdbc_hook.get_uri() == expected_uri

0 commit comments

Comments
 (0)