Skip to content

Fix #1051: Replace lru_cache with TTL-based key caching #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to `Semantic Versioning <https://semver.org/>`__.

Fixed
~~~~~
- Fix indefinite key caching in PyJWKClient by replacing lru_cache with TTL-aware cache in `#1070 <https://github.com/jpadilla/pyjwt/pull/1070>`__
- Validate key against allowed types for Algorithm family in `#964 <https://github.com/jpadilla/pyjwt/pull/964>`__
- Add iterator for JWKSet in `#1041 <https://github.com/jpadilla/pyjwt/pull/1041>`__
- Validate `iss` claim is a string during encoding and decoding by @pachewise in `#1040 <https://github.com/jpadilla/pyjwt/pull/1040>`__
Expand Down
54 changes: 47 additions & 7 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import json
import time
import urllib.request
from functools import lru_cache
from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError
Expand Down Expand Up @@ -44,12 +44,17 @@ def __init__(
else:
self.jwk_set_cache = None

# Replace lru_cache with TTL-aware individual key cache
# Use the same TTL as JWKSetCache for consistency
if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(
self.get_signing_key
) # type: ignore
self._key_cache_enabled = True
self._key_cache: Dict[
str, tuple[PyJWK, float]
] = {} # kid -> (key, timestamp)
self._max_cached_keys = max_cached_keys
self._key_cache_ttl = lifespan # Use same TTL as JWKSetCache
else:
self._key_cache_enabled = False

def fetch_data(self) -> Any:
jwk_set: Any = None
Expand Down Expand Up @@ -95,12 +100,45 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:

return signing_keys

def _get_cached_key(self, kid: str) -> Optional[PyJWK]:
"""Get a cached key if it exists and hasn't expired."""
if not self._key_cache_enabled or kid not in self._key_cache:
return None

key, timestamp = self._key_cache[kid]

# Check and remove if expired (use same logic as JWKSetCache)
if time.monotonic() - timestamp > self._key_cache_ttl:
del self._key_cache[kid]
return None

return key

def _cache_key(self, kid: str, key: PyJWK) -> None:
"""Cache a key with current timestamp."""
if not self._key_cache_enabled:
return

# Evict oldest if at capacity
if len(self._key_cache) >= self._max_cached_keys and kid not in self._key_cache:
# Simple eviction: remove oldest timestamp
oldest_kid = min(
self._key_cache.keys(), key=lambda k: self._key_cache[k][1]
)
del self._key_cache[oldest_kid]

self._key_cache[kid] = (key, time.monotonic())

def get_signing_key(self, kid: str) -> PyJWK:
# Check TTL-aware cache first
cached_key = self._get_cached_key(kid)
if cached_key is not None:
return cached_key

signing_keys = self.get_signing_keys()
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = self.get_signing_keys(refresh=True)
signing_key = self.match_kid(signing_keys, kid)

Expand All @@ -109,6 +147,8 @@ def get_signing_key(self, kid: str) -> PyJWK:
f'Unable to find a signing key that matches: "{kid}"'
)

# Cache the key with TTL (not lru)
self._cache_key(kid, signing_key)
return signing_key

def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
Expand Down
78 changes: 78 additions & 0 deletions tests/test_jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,81 @@ def test_get_jwt_set_sslcontext_no_ca(self):
jwks_client.get_jwk_set()

assert "Failed to get an expected error"

def test_security_fix_revoked_keys_expire(self):
"""
Test that demonstrates the security fix working.

This test should:
- FAIL with old lru_cache implementation (serves revoked keys forever)
- PASS with new TTL implementation (revoked keys expire)
"""
import time
from unittest.mock import patch

client = PyJWKClient("https://example.com", cache_keys=True, lifespan=0.1)

# Use the real RSA key from existing tests
real_rsa_key = {
"kid": "revoked-key-123",
"kty": "RSA",
"use": "sig",
"n": "0wtlJRY9-ru61LmOgieeI7_rD1oIna9QpBMAOWw8wTuoIhFQFwcIi7MFB7IEfelCPj08vkfLsuFtR8cG07EE4uvJ78bAqRjMsCvprWp4e2p7hqPnWcpRpDEyHjzirEJle1LPpjLLVaSWgkbrVaOD0lkWkP1T1TkrOset_Obh8BwtO-Ww-UfrEwxTyz1646AGkbT2nL8PX0trXrmira8GnrCkFUgTUS61GoTdb9bCJ19PLX9Gnxw7J0BtR0GubopXq8KlI0ThVql6ZtVGN2dvmrCPAVAZleM5TVB61m0VSXvGWaF6_GeOhbFoyWcyUmFvzWhBm8Q38vWgsSI7oHTkEw",
"e": "AQAB",
}

different_key = {
"kid": "different-key-456",
"kty": "RSA",
"use": "sig",
"n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ",
"e": "AQAB",
}

jwks_with_key = {"keys": [real_rsa_key]}
jwks_key_revoked = {"keys": [different_key]} # Simulate original key is gone

with patch.object(client, "fetch_data") as mock_fetch:
# Step 1: Get key (it gets cached)
mock_fetch.return_value = jwks_with_key
key1 = client.get_signing_key("revoked-key-123")
assert key1.key_id == "revoked-key-123"

# Step 2: Wait for cache to expire
time.sleep(0.15) # Longer than lifespan

# Step 3: Key is now "revoked" (removed from JWKS)
mock_fetch.return_value = jwks_key_revoked

# Step 4: THE SECURITY TEST
# Expected behavior: Should raise exception when key is revoked
# Old vulnerable code: Returns cached key, test fails
# New secure code: Raises exception as expected, test passes
with pytest.raises(PyJWKClientError, match="Unable to find a signing key"):
client.get_signing_key("revoked-key-123")

def test_key_cache_eviction_when_at_capacity(self):
"""Test that key cache evicts oldest entries when at capacity."""
from unittest.mock import MagicMock

client = PyJWKClient("https://example.com", cache_keys=True, max_cached_keys=2)

key1 = MagicMock()
key1.key_id = "key1"
key2 = MagicMock()
key2.key_id = "key2"
key3 = MagicMock()
key3.key_id = "key3"

# Fill cache to capacity
client._cache_key("key1", key1)
client._cache_key("key2", key2)
assert len(client._key_cache) == 2

# Add third key - should evict oldest (key1)
client._cache_key("key3", key3)
assert len(client._key_cache) == 2

assert "key1" not in client._key_cache
assert "key2" in client._key_cache
assert "key3" in client._key_cache