Skip to content

Commit e3c37ee

Browse files
authored
Fixes context propagation, adds context tests (#1369)
1 parent e3bfcf8 commit e3c37ee

File tree

7 files changed

+213
-6
lines changed

7 files changed

+213
-6
lines changed

aiobotocore/context.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from contextlib import asynccontextmanager
22
from copy import deepcopy
33
from functools import wraps
4-
from inspect import iscoroutinefunction
54

65
from botocore.context import (
76
ClientContext,
@@ -10,6 +9,8 @@
109
set_context,
1110
)
1211

12+
from ._helpers import resolve_awaitable
13+
1314

1415
@asynccontextmanager
1516
async def start_as_current_context(ctx=None):
@@ -31,10 +32,7 @@ def decorator(func):
3132
async def wrapper(*args, **kwargs):
3233
async with start_as_current_context():
3334
if hook:
34-
if iscoroutinefunction(hook):
35-
await hook()
36-
else:
37-
hook()
35+
await resolve_awaitable(hook())
3836
return await func(*args, **kwargs)
3937

4038
return wrapper

aiobotocore/paginate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
from functools import partial
2+
13
import aioitertools
24
import jmespath
35
from botocore.exceptions import PaginationError
46
from botocore.paginate import PageIterator, Paginator
7+
from botocore.useragent import register_feature_id
58
from botocore.utils import merge_dicts, set_value_from_jmespath
69

10+
from .context import with_current_context
11+
712

813
class AioPageIterator(PageIterator):
914
def __aiter__(self):
1015
return self.__anext__()
1116

17+
@with_current_context(partial(register_feature_id, 'PAGINATOR'))
18+
async def _make_request(self, current_kwargs):
19+
return await self._method(**current_kwargs)
20+
1221
async def __anext__(self):
1322
current_kwargs = self._op_kwargs
1423
previous_next_token = None

aiobotocore/waiter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from functools import partial
33

44
# WaiterModel is required for client.py import
5-
from botocore.context import with_current_context
65
from botocore.docs.docstring import WaiterDocstring
76
from botocore.exceptions import ClientError
87
from botocore.useragent import register_feature_id
@@ -19,6 +18,8 @@
1918
xform_name,
2019
)
2120

21+
from .context import with_current_context
22+
2223

2324
def create_waiter_with_client(waiter_name, waiter_model, client):
2425
"""

tests/botocore_tests/__init__.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from botocore.compat import parse_qs, urlparse
2525

2626
import aiobotocore.session
27+
from aiobotocore.awsrequest import AioAWSResponse
2728

2829
_LOADER = botocore.loaders.Loader()
2930

@@ -93,3 +94,72 @@ def assert_url_equal(url1, url2):
9394
assert parts1.hostname == parts2.hostname
9495
assert parts1.port == parts2.port
9596
assert parse_qs(parts1.query) == parse_qs(parts2.query)
97+
98+
99+
class HTTPStubberException(Exception):
100+
pass
101+
102+
103+
class BaseHTTPStubber:
104+
class AsyncFileWrapper:
105+
def __init__(self, body: bytes):
106+
self._body = body
107+
108+
async def read(self):
109+
return self._body
110+
111+
def __init__(self, obj_with_event_emitter, strict=True):
112+
self.reset()
113+
self._strict = strict
114+
self._obj_with_event_emitter = obj_with_event_emitter
115+
116+
def reset(self):
117+
self.requests = []
118+
self.responses = []
119+
120+
def add_response(
121+
self, url='https://example.com', status=200, headers=None, body=b''
122+
):
123+
if headers is None:
124+
headers = {}
125+
126+
response = AioAWSResponse(
127+
url, status, headers, self.AsyncFileWrapper(body)
128+
)
129+
self.responses.append(response)
130+
131+
@property
132+
def _events(self):
133+
raise NotImplementedError('_events')
134+
135+
def start(self):
136+
self._events.register('before-send', self)
137+
138+
def stop(self):
139+
self._events.unregister('before-send', self)
140+
141+
def __enter__(self):
142+
self.start()
143+
return self
144+
145+
def __exit__(self, exc_type, exc_value, traceback):
146+
self.stop()
147+
148+
def __call__(self, request, **kwargs):
149+
self.requests.append(request)
150+
if self.responses:
151+
response = self.responses.pop(0)
152+
if isinstance(response, Exception):
153+
raise response
154+
else:
155+
return response
156+
elif self._strict:
157+
raise HTTPStubberException('Insufficient responses')
158+
else:
159+
return None
160+
161+
162+
class ClientHTTPStubber(BaseHTTPStubber):
163+
@property
164+
def _events(self):
165+
return self._obj_with_event_emitter.meta.events

tests/botocore_tests/functional/__init__.py

Whitespace-only changes.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import asyncio
2+
3+
from aiobotocore.session import AioSession
4+
5+
from ...mock_server import AIOServer
6+
from .. import ClientHTTPStubber
7+
8+
9+
def get_captured_ua_strings(stubber):
10+
"""Get captured request-level user agent strings from stubber.
11+
:type stubber: tests.BaseHTTPStubber
12+
"""
13+
return [req.headers['User-Agent'].decode() for req in stubber.requests]
14+
15+
16+
def parse_registered_feature_ids(ua_string):
17+
"""Parse registered feature ids in user agent string.
18+
:type ua_string: str
19+
:rtype: list[str]
20+
"""
21+
ua_fields = ua_string.split(' ')
22+
feature_field = [field for field in ua_fields if field.startswith('m/')][0]
23+
return feature_field[2:].split(',')
24+
25+
26+
async def test_user_agent_has_registered_feature_id():
27+
session = AioSession()
28+
29+
async with (
30+
AIOServer() as server,
31+
session.create_client(
32+
's3',
33+
endpoint_url=server.endpoint_url,
34+
aws_secret_access_key='xxx',
35+
aws_access_key_id='xxx',
36+
) as s3_client,
37+
):
38+
with ClientHTTPStubber(s3_client) as stub_client:
39+
stub_client.add_response()
40+
paginator = s3_client.get_paginator('list_buckets')
41+
# The `paginate()` method registers `'PAGINATOR': 'C'`
42+
async for _ in paginator.paginate():
43+
pass
44+
45+
ua_string = get_captured_ua_strings(stub_client)[0]
46+
feature_list = parse_registered_feature_ids(ua_string)
47+
assert 'C' in feature_list
48+
49+
50+
async def test_registered_feature_ids_dont_bleed_between_requests():
51+
session = AioSession()
52+
53+
async with (
54+
AIOServer() as server,
55+
session.create_client(
56+
's3',
57+
endpoint_url=server.endpoint_url,
58+
aws_secret_access_key='xxx',
59+
aws_access_key_id='xxx',
60+
) as s3_client,
61+
):
62+
with ClientHTTPStubber(s3_client) as stub_client:
63+
stub_client.add_response()
64+
waiter = s3_client.get_waiter('bucket_exists')
65+
# The `wait()` method registers `'WAITER': 'B'`
66+
await waiter.wait(Bucket='mybucket')
67+
68+
stub_client.add_response()
69+
paginator = s3_client.get_paginator('list_buckets')
70+
# The `paginate()` method registers `'PAGINATOR': 'C'`
71+
async for _ in paginator.paginate():
72+
pass
73+
74+
ua_strings = get_captured_ua_strings(stub_client)
75+
waiter_feature_list = parse_registered_feature_ids(ua_strings[0])
76+
assert 'B' in waiter_feature_list
77+
78+
paginator_feature_list = parse_registered_feature_ids(ua_strings[1])
79+
assert 'C' in paginator_feature_list
80+
assert 'B' not in paginator_feature_list
81+
82+
83+
# This tests context's bleeding across tasks instead
84+
async def test_registered_feature_ids_dont_bleed_across_threads():
85+
session = AioSession()
86+
87+
async with (
88+
AIOServer() as server,
89+
session.create_client(
90+
's3',
91+
endpoint_url=server.endpoint_url,
92+
aws_secret_access_key='xxx',
93+
aws_access_key_id='xxx',
94+
) as s3_client,
95+
):
96+
97+
async def wait():
98+
with ClientHTTPStubber(s3_client) as stub_client:
99+
stub_client.add_response()
100+
waiter = s3_client.get_waiter('bucket_exists')
101+
# The `wait()` method registers `'WAITER': 'B'`
102+
await waiter.wait(Bucket='mybucket')
103+
ua_string = get_captured_ua_strings(stub_client)[0]
104+
return parse_registered_feature_ids(ua_string)
105+
106+
async def paginate():
107+
with ClientHTTPStubber(s3_client) as stub_client:
108+
stub_client.add_response()
109+
paginator = s3_client.get_paginator('list_buckets')
110+
# The `paginate()` method registers `'PAGINATOR': 'C'`
111+
async for _ in paginator.paginate():
112+
pass
113+
ua_string = get_captured_ua_strings(stub_client)[0]
114+
return parse_registered_feature_ids(ua_string)
115+
116+
waiter_features, paginator_features = await asyncio.gather(
117+
wait(), paginate()
118+
)
119+
120+
assert 'B' in waiter_features
121+
assert 'C' not in waiter_features
122+
assert 'C' in paginator_features
123+
assert 'B' not in paginator_features

tests/test_patches.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,12 @@ def test_protocol_parsers():
786786
'a7e83728338e61ff2ca0a26c6f03c67cbabffc32',
787787
},
788788
),
789+
(
790+
PageIterator._make_request,
791+
{
792+
'e926671018897ac5851a3add5d2bc15a2d6142df',
793+
},
794+
),
789795
(
790796
PageIterator.result_key_iters,
791797
{

0 commit comments

Comments
 (0)