Skip to content

Commit 25be79a

Browse files
authored
Snowflake SQL API hook Retry Logic (apache#51463)
* added retry logic for sync and async requests to snowflake api * first draft of unit tests * unit tests for hook and retries * remove comment * update sync request to use request.request * mypy fixes * updated sync and async api call methods to use tenacity context manager * update unit test with correct method * retry args docs * reorder so self.log is initialized
1 parent 35bc037 commit 25be79a

File tree

3 files changed

+518
-60
lines changed

3 files changed

+518
-60
lines changed

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py

Lines changed: 111 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,18 @@
2525

2626
import aiohttp
2727
import requests
28+
from aiohttp import ClientConnectionError, ClientResponseError
2829
from cryptography.hazmat.backends import default_backend
2930
from cryptography.hazmat.primitives import serialization
31+
from requests.exceptions import ConnectionError, HTTPError, Timeout
32+
from tenacity import (
33+
AsyncRetrying,
34+
Retrying,
35+
before_sleep_log,
36+
retry_if_exception,
37+
stop_after_attempt,
38+
wait_exponential,
39+
)
3040

3141
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
3242
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
@@ -65,6 +75,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
6575
:param token_life_time: lifetime of the JWT Token in timedelta
6676
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
6777
:param deferrable: Run operator in the deferrable mode.
78+
:param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
6879
"""
6980

7081
LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime
@@ -75,15 +86,27 @@ def __init__(
7586
snowflake_conn_id: str,
7687
token_life_time: timedelta = LIFETIME,
7788
token_renewal_delta: timedelta = RENEWAL_DELTA,
89+
api_retry_args: dict[Any, Any] | None = None, # Optional retry arguments passed to tenacity.retry
7890
*args: Any,
7991
**kwargs: Any,
8092
):
8193
self.snowflake_conn_id = snowflake_conn_id
8294
self.token_life_time = token_life_time
8395
self.token_renewal_delta = token_renewal_delta
96+
8497
super().__init__(snowflake_conn_id, *args, **kwargs)
8598
self.private_key: Any = None
8699

100+
self.retry_config = {
101+
"retry": retry_if_exception(self._should_retry_on_error),
102+
"wait": wait_exponential(multiplier=1, min=1, max=60),
103+
"stop": stop_after_attempt(5),
104+
"before_sleep": before_sleep_log(self.log, log_level=20), # INFO level
105+
"reraise": True,
106+
}
107+
if api_retry_args:
108+
self.retry_config.update(api_retry_args)
109+
87110
def get_private_key(self) -> None:
88111
"""Get the private key from snowflake connection."""
89112
conn = self.get_connection(self.snowflake_conn_id)
@@ -168,13 +191,8 @@ def execute_query(
168191
"query_tag": query_tag,
169192
},
170193
}
171-
response = requests.post(url, json=data, headers=headers, params=params)
172-
try:
173-
response.raise_for_status()
174-
except requests.exceptions.HTTPError as e: # pragma: no cover
175-
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
176-
raise AirflowException(msg)
177-
json_response = response.json()
194+
195+
_, json_response = self._make_api_call_with_retries("POST", url, headers, params, data)
178196
self.log.info("Snowflake SQL POST API response: %s", json_response)
179197
if "statementHandles" in json_response:
180198
self.query_ids = json_response["statementHandles"]
@@ -259,13 +277,10 @@ def check_query_output(self, query_ids: list[str]) -> None:
259277
"""
260278
for query_id in query_ids:
261279
header, params, url = self.get_request_url_header_params(query_id)
262-
try:
263-
response = requests.get(url, headers=header, params=params)
264-
response.raise_for_status()
265-
self.log.info(response.json())
266-
except requests.exceptions.HTTPError as e:
267-
msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}"
268-
raise AirflowException(msg)
280+
_, response_json = self._make_api_call_with_retries(
281+
method="GET", url=url, headers=header, params=params
282+
)
283+
self.log.info(response_json)
269284

270285
def _process_response(self, status_code, resp):
271286
self.log.info("Snowflake SQL GET statements status API response: %s", resp)
@@ -295,9 +310,7 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
295310
"""
296311
self.log.info("Retrieving status for query id %s", query_id)
297312
header, params, url = self.get_request_url_header_params(query_id)
298-
response = requests.get(url, params=params, headers=header)
299-
status_code = response.status_code
300-
resp = response.json()
313+
status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
301314
return self._process_response(status_code, resp)
302315

303316
async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
@@ -308,10 +321,85 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str |
308321
"""
309322
self.log.info("Retrieving status for query id %s", query_id)
310323
header, params, url = self.get_request_url_header_params(query_id)
311-
async with (
312-
aiohttp.ClientSession(headers=header) as session,
313-
session.get(url, params=params) as response,
324+
status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params)
325+
return self._process_response(status_code, resp)
326+
327+
@staticmethod
328+
def _should_retry_on_error(exception) -> bool:
329+
"""
330+
Determine if the exception should trigger a retry based on error type and status code.
331+
332+
Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable),
333+
and 504 (Gateway Timeout) as recommended by Snowflake error handling docs.
334+
Retries on connection errors and timeouts.
335+
336+
:param exception: The exception to check
337+
:return: True if the request should be retried, False otherwise
338+
"""
339+
if isinstance(exception, HTTPError):
340+
return exception.response.status_code in [429, 503, 504]
341+
if isinstance(exception, ClientResponseError):
342+
return exception.status in [429, 503, 504]
343+
if isinstance(
344+
exception,
345+
(
346+
ConnectionError,
347+
Timeout,
348+
ClientConnectionError,
349+
),
314350
):
315-
status_code = response.status
316-
resp = await response.json()
317-
return self._process_response(status_code, resp)
351+
return True
352+
return False
353+
354+
def _make_api_call_with_retries(
355+
self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None
356+
):
357+
"""
358+
Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors.
359+
360+
Error handling implemented based on Snowflake error handling docs:
361+
https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
362+
363+
:param method: The HTTP method to use for the API call.
364+
:param url: The URL for the API endpoint.
365+
:param headers: The headers to include in the API call.
366+
:param params: (Optional) The query parameters to include in the API call.
367+
:param data: (Optional) The data to include in the API call.
368+
:return: The response object from the API call.
369+
"""
370+
with requests.Session() as session:
371+
for attempt in Retrying(**self.retry_config): # type: ignore
372+
with attempt:
373+
if method.upper() in ("GET", "POST"):
374+
response = session.request(
375+
method=method.lower(), url=url, headers=headers, params=params, json=json
376+
)
377+
else:
378+
raise ValueError(f"Unsupported HTTP method: {method}")
379+
response.raise_for_status()
380+
return response.status_code, response.json()
381+
382+
async def _make_api_call_with_retries_async(self, method, url, headers, params=None):
383+
"""
384+
Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors.
385+
386+
Error handling implemented based on Snowflake error handling docs:
387+
https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
388+
389+
:param method: The HTTP method to use for the API call. Only GET is supported as is synchronous.
390+
:param url: The URL for the API endpoint.
391+
:param headers: The headers to include in the API call.
392+
:param params: (Optional) The query parameters to include in the API call.
393+
:return: The response object from the API call.
394+
"""
395+
async with aiohttp.ClientSession(headers=headers) as session:
396+
async for attempt in AsyncRetrying(**self.retry_config): # type: ignore
397+
with attempt:
398+
if method.upper() == "GET":
399+
async with session.request(method=method.lower(), url=url, params=params) as response:
400+
response.raise_for_status()
401+
# Return status and json content for async processing
402+
content = await response.json()
403+
return response.status, content
404+
else:
405+
raise ValueError(f"Unsupported HTTP method: {method}")

providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
355355
When executing the statement, Snowflake replaces placeholders (? and :name) in
356356
the statement with these specified values.
357357
:param deferrable: Run operator in the deferrable mode.
358+
:param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
358359
"""
359360

360361
LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime
@@ -381,6 +382,7 @@ def __init__(
381382
token_renewal_delta: timedelta = RENEWAL_DELTA,
382383
bindings: dict[str, Any] | None = None,
383384
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
385+
snowflake_api_retry_args: dict[str, Any] | None = None,
384386
**kwargs: Any,
385387
) -> None:
386388
self.snowflake_conn_id = snowflake_conn_id
@@ -390,6 +392,7 @@ def __init__(
390392
self.token_renewal_delta = token_renewal_delta
391393
self.bindings = bindings
392394
self.execute_async = False
395+
self.snowflake_api_retry_args = snowflake_api_retry_args or {}
393396
self.deferrable = deferrable
394397
self.query_ids: list[str] = []
395398
if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
@@ -412,6 +415,7 @@ def _hook(self):
412415
token_life_time=self.token_life_time,
413416
token_renewal_delta=self.token_renewal_delta,
414417
deferrable=self.deferrable,
418+
api_retry_args=self.snowflake_api_retry_args,
415419
**self.hook_params,
416420
)
417421

0 commit comments

Comments
 (0)