25
25
26
26
import aiohttp
27
27
import requests
28
+ from aiohttp import ClientConnectionError , ClientResponseError
28
29
from cryptography .hazmat .backends import default_backend
29
30
from cryptography .hazmat .primitives import serialization
31
+ from requests .exceptions import ConnectionError , HTTPError , Timeout
32
+ from tenacity import (
33
+ AsyncRetrying ,
34
+ Retrying ,
35
+ before_sleep_log ,
36
+ retry_if_exception ,
37
+ stop_after_attempt ,
38
+ wait_exponential ,
39
+ )
30
40
31
41
from airflow .exceptions import AirflowException , AirflowProviderDeprecationWarning
32
42
from airflow .providers .snowflake .hooks .snowflake import SnowflakeHook
@@ -65,6 +75,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
65
75
:param token_life_time: lifetime of the JWT Token in timedelta
66
76
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
67
77
:param deferrable: Run operator in the deferrable mode.
78
+ :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
68
79
"""
69
80
70
81
LIFETIME = timedelta (minutes = 59 ) # The tokens will have a 59 minute lifetime
@@ -75,15 +86,27 @@ def __init__(
75
86
snowflake_conn_id : str ,
76
87
token_life_time : timedelta = LIFETIME ,
77
88
token_renewal_delta : timedelta = RENEWAL_DELTA ,
89
+ api_retry_args : dict [Any , Any ] | None = None , # Optional retry arguments passed to tenacity.retry
78
90
* args : Any ,
79
91
** kwargs : Any ,
80
92
):
81
93
self .snowflake_conn_id = snowflake_conn_id
82
94
self .token_life_time = token_life_time
83
95
self .token_renewal_delta = token_renewal_delta
96
+
84
97
super ().__init__ (snowflake_conn_id , * args , ** kwargs )
85
98
self .private_key : Any = None
86
99
100
+ self .retry_config = {
101
+ "retry" : retry_if_exception (self ._should_retry_on_error ),
102
+ "wait" : wait_exponential (multiplier = 1 , min = 1 , max = 60 ),
103
+ "stop" : stop_after_attempt (5 ),
104
+ "before_sleep" : before_sleep_log (self .log , log_level = 20 ), # INFO level
105
+ "reraise" : True ,
106
+ }
107
+ if api_retry_args :
108
+ self .retry_config .update (api_retry_args )
109
+
87
110
def get_private_key (self ) -> None :
88
111
"""Get the private key from snowflake connection."""
89
112
conn = self .get_connection (self .snowflake_conn_id )
@@ -168,13 +191,8 @@ def execute_query(
168
191
"query_tag" : query_tag ,
169
192
},
170
193
}
171
- response = requests .post (url , json = data , headers = headers , params = params )
172
- try :
173
- response .raise_for_status ()
174
- except requests .exceptions .HTTPError as e : # pragma: no cover
175
- msg = f"Response: { e .response .content .decode ()} Status Code: { e .response .status_code } "
176
- raise AirflowException (msg )
177
- json_response = response .json ()
194
+
195
+ _ , json_response = self ._make_api_call_with_retries ("POST" , url , headers , params , data )
178
196
self .log .info ("Snowflake SQL POST API response: %s" , json_response )
179
197
if "statementHandles" in json_response :
180
198
self .query_ids = json_response ["statementHandles" ]
@@ -259,13 +277,10 @@ def check_query_output(self, query_ids: list[str]) -> None:
259
277
"""
260
278
for query_id in query_ids :
261
279
header , params , url = self .get_request_url_header_params (query_id )
262
- try :
263
- response = requests .get (url , headers = header , params = params )
264
- response .raise_for_status ()
265
- self .log .info (response .json ())
266
- except requests .exceptions .HTTPError as e :
267
- msg = f"Response: { e .response .content .decode ()} , Status Code: { e .response .status_code } "
268
- raise AirflowException (msg )
280
+ _ , response_json = self ._make_api_call_with_retries (
281
+ method = "GET" , url = url , headers = header , params = params
282
+ )
283
+ self .log .info (response_json )
269
284
270
285
def _process_response (self , status_code , resp ):
271
286
self .log .info ("Snowflake SQL GET statements status API response: %s" , resp )
@@ -295,9 +310,7 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
295
310
"""
296
311
self .log .info ("Retrieving status for query id %s" , query_id )
297
312
header , params , url = self .get_request_url_header_params (query_id )
298
- response = requests .get (url , params = params , headers = header )
299
- status_code = response .status_code
300
- resp = response .json ()
313
+ status_code , resp = self ._make_api_call_with_retries ("GET" , url , header , params )
301
314
return self ._process_response (status_code , resp )
302
315
303
316
async def get_sql_api_query_status_async (self , query_id : str ) -> dict [str , str | list [str ]]:
@@ -308,10 +321,85 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str |
308
321
"""
309
322
self .log .info ("Retrieving status for query id %s" , query_id )
310
323
header , params , url = self .get_request_url_header_params (query_id )
311
- async with (
312
- aiohttp .ClientSession (headers = header ) as session ,
313
- session .get (url , params = params ) as response ,
324
+ status_code , resp = await self ._make_api_call_with_retries_async ("GET" , url , header , params )
325
+ return self ._process_response (status_code , resp )
326
+
327
+ @staticmethod
328
+ def _should_retry_on_error (exception ) -> bool :
329
+ """
330
+ Determine if the exception should trigger a retry based on error type and status code.
331
+
332
+ Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable),
333
+ and 504 (Gateway Timeout) as recommended by Snowflake error handling docs.
334
+ Retries on connection errors and timeouts.
335
+
336
+ :param exception: The exception to check
337
+ :return: True if the request should be retried, False otherwise
338
+ """
339
+ if isinstance (exception , HTTPError ):
340
+ return exception .response .status_code in [429 , 503 , 504 ]
341
+ if isinstance (exception , ClientResponseError ):
342
+ return exception .status in [429 , 503 , 504 ]
343
+ if isinstance (
344
+ exception ,
345
+ (
346
+ ConnectionError ,
347
+ Timeout ,
348
+ ClientConnectionError ,
349
+ ),
314
350
):
315
- status_code = response .status
316
- resp = await response .json ()
317
- return self ._process_response (status_code , resp )
351
+ return True
352
+ return False
353
+
354
+ def _make_api_call_with_retries (
355
+ self , method : str , url : str , headers : dict , params : dict | None = None , json : dict | None = None
356
+ ):
357
+ """
358
+ Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors.
359
+
360
+ Error handling implemented based on Snowflake error handling docs:
361
+ https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
362
+
363
+ :param method: The HTTP method to use for the API call.
364
+ :param url: The URL for the API endpoint.
365
+ :param headers: The headers to include in the API call.
366
+ :param params: (Optional) The query parameters to include in the API call.
367
+ :param data: (Optional) The data to include in the API call.
368
+ :return: The response object from the API call.
369
+ """
370
+ with requests .Session () as session :
371
+ for attempt in Retrying (** self .retry_config ): # type: ignore
372
+ with attempt :
373
+ if method .upper () in ("GET" , "POST" ):
374
+ response = session .request (
375
+ method = method .lower (), url = url , headers = headers , params = params , json = json
376
+ )
377
+ else :
378
+ raise ValueError (f"Unsupported HTTP method: { method } " )
379
+ response .raise_for_status ()
380
+ return response .status_code , response .json ()
381
+
382
+ async def _make_api_call_with_retries_async (self , method , url , headers , params = None ):
383
+ """
384
+ Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors.
385
+
386
+ Error handling implemented based on Snowflake error handling docs:
387
+ https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
388
+
389
+ :param method: The HTTP method to use for the API call. Only GET is supported as is synchronous.
390
+ :param url: The URL for the API endpoint.
391
+ :param headers: The headers to include in the API call.
392
+ :param params: (Optional) The query parameters to include in the API call.
393
+ :return: The response object from the API call.
394
+ """
395
+ async with aiohttp .ClientSession (headers = headers ) as session :
396
+ async for attempt in AsyncRetrying (** self .retry_config ): # type: ignore
397
+ with attempt :
398
+ if method .upper () == "GET" :
399
+ async with session .request (method = method .lower (), url = url , params = params ) as response :
400
+ response .raise_for_status ()
401
+ # Return status and json content for async processing
402
+ content = await response .json ()
403
+ return response .status , content
404
+ else :
405
+ raise ValueError (f"Unsupported HTTP method: { method } " )
0 commit comments