Skip to content

Commit 11b357a

Browse files
Fix issue #1051: Replace lru_cache with TTL-based key caching
Replace functools.lru_cache with TTL-aware cache to prevent indefinite caching of potentially revoked signing keys. - Remove lru_cache which cached keys forever - Implement TTL-based individual key caching - Keys now expire after configured lifespan - Maintains backward compatibility with existing API - Add test demonstrating fix works Fixes #1051
1 parent 3942ec3 commit 11b357a

File tree

2 files changed

+97
-7
lines changed

2 files changed

+97
-7
lines changed

jwt/jwks_client.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import json
4+
import time
45
import urllib.request
5-
from functools import lru_cache
66
from ssl import SSLContext
77
from typing import Any, Dict, List, Optional
88
from urllib.error import URLError
@@ -44,12 +44,15 @@ def __init__(
4444
else:
4545
self.jwk_set_cache = None
4646

47+
# Replace lru_cache with TTL-aware individual key cache
48+
# Use the same TTL as JWKSetCache for consistency
4749
if cache_keys:
48-
# Cache signing keys
49-
# Ignore mypy (https://github.com/python/mypy/issues/2427)
50-
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(
51-
self.get_signing_key
52-
) # type: ignore
50+
self._key_cache_enabled = True
51+
self._key_cache: Dict[str, tuple[PyJWK, float]] = {} # kid -> (key, timestamp)
52+
self._max_cached_keys = max_cached_keys
53+
self._key_cache_ttl = lifespan # Use same TTL as JWKSetCache
54+
else:
55+
self._key_cache_enabled = False
5356

5457
def fetch_data(self) -> Any:
5558
jwk_set: Any = None
@@ -95,12 +98,44 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
9598

9699
return signing_keys
97100

101+
def _get_cached_key(self, kid: str) -> Optional[PyJWK]:
102+
"""Get a cached key if it exists and hasn't expired."""
103+
if not self._key_cache_enabled or kid not in self._key_cache:
104+
return None
105+
106+
key, timestamp = self._key_cache[kid]
107+
108+
# Check and remove if expired (use same logic as JWKSetCache)
109+
if time.monotonic() - timestamp > self._key_cache_ttl:
110+
del self._key_cache[kid]
111+
return None
112+
113+
return key
114+
115+
def _cache_key(self, kid: str, key: PyJWK) -> None:
116+
"""Cache a key with current timestamp."""
117+
if not self._key_cache_enabled:
118+
return
119+
120+
# Evict oldest if at capacity
121+
if len(self._key_cache) >= self._max_cached_keys and kid not in self._key_cache:
122+
# Simple eviction: remove oldest timestamp
123+
oldest_kid = min(self._key_cache.keys(),
124+
key=lambda k: self._key_cache[k][1])
125+
del self._key_cache[oldest_kid]
126+
127+
self._key_cache[kid] = (key, time.monotonic())
128+
98129
def get_signing_key(self, kid: str) -> PyJWK:
130+
# Check TTL-aware cache first
131+
cached_key = self._get_cached_key(kid)
132+
if cached_key is not None:
133+
return cached_key
134+
99135
signing_keys = self.get_signing_keys()
100136
signing_key = self.match_kid(signing_keys, kid)
101137

102138
if not signing_key:
103-
# If no matching signing key from the jwk set, refresh the jwk set and try again.
104139
signing_keys = self.get_signing_keys(refresh=True)
105140
signing_key = self.match_kid(signing_keys, kid)
106141

@@ -109,6 +144,8 @@ def get_signing_key(self, kid: str) -> PyJWK:
109144
f'Unable to find a signing key that matches: "{kid}"'
110145
)
111146

147+
# Cache the key with TTL (not lru)
148+
self._cache_key(kid, signing_key)
112149
return signing_key
113150

114151
def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:

tests/test_jwks_client.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,56 @@ def test_get_jwt_set_sslcontext_no_ca(self):
355355
jwks_client.get_jwk_set()
356356

357357
assert "Failed to get an expected error"
358+
359+
def test_security_fix_revoked_keys_expire(self):
360+
"""
361+
Test that demonstrates the security fix working.
362+
363+
This test should:
364+
- FAIL with old lru_cache implementation (serves revoked keys forever)
365+
- PASS with new TTL implementation (revoked keys expire)
366+
"""
367+
import time
368+
from unittest.mock import patch
369+
370+
client = PyJWKClient("https://example.com", cache_keys=True, lifespan=0.1)
371+
372+
# Use the real RSA key from existing tests
373+
real_rsa_key = {
374+
"kid": "revoked-key-123",
375+
"kty": "RSA",
376+
"use": "sig",
377+
"n": "0wtlJRY9-ru61LmOgieeI7_rD1oIna9QpBMAOWw8wTuoIhFQFwcIi7MFB7IEfelCPj08vkfLsuFtR8cG07EE4uvJ78bAqRjMsCvprWp4e2p7hqPnWcpRpDEyHjzirEJle1LPpjLLVaSWgkbrVaOD0lkWkP1T1TkrOset_Obh8BwtO-Ww-UfrEwxTyz1646AGkbT2nL8PX0trXrmira8GnrCkFUgTUS61GoTdb9bCJ19PLX9Gnxw7J0BtR0GubopXq8KlI0ThVql6ZtVGN2dvmrCPAVAZleM5TVB61m0VSXvGWaF6_GeOhbFoyWcyUmFvzWhBm8Q38vWgsSI7oHTkEw",
378+
"e": "AQAB"
379+
}
380+
381+
different_key = {
382+
"kid": "different-key-456",
383+
"kty": "RSA",
384+
"use": "sig",
385+
"n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ",
386+
"e": "AQAB"
387+
}
388+
389+
jwks_with_key = {"keys": [real_rsa_key]}
390+
jwks_key_revoked = {"keys": [different_key]} # Simulate original key is gone
391+
392+
with patch.object(client, 'fetch_data') as mock_fetch:
393+
# Step 1: Get key (it gets cached)
394+
mock_fetch.return_value = jwks_with_key
395+
key1 = client.get_signing_key("revoked-key-123")
396+
assert key1.key_id == "revoked-key-123"
397+
398+
# Step 2: Wait for cache to expire
399+
time.sleep(0.15) # Longer than lifespan
400+
401+
# Step 3: Key is now "revoked" (removed from JWKS)
402+
mock_fetch.return_value = jwks_key_revoked
403+
404+
# Step 4: THE SECURITY TEST
405+
# Expected behavior: Should raise exception when key is revoked
406+
# Old vulnerable code: Returns cached key, test fails
407+
# New secure code: Raises exception as expected, test passes
408+
with pytest.raises(PyJWKClientError, match="Unable to find a signing key"):
409+
client.get_signing_key("revoked-key-123")
410+

0 commit comments

Comments
 (0)