1
1
from __future__ import annotations
2
2
3
3
import json
4
+ import time
4
5
import urllib .request
5
- from functools import lru_cache
6
6
from ssl import SSLContext
7
7
from typing import Any , Dict , List , Optional
8
8
from urllib .error import URLError
@@ -44,12 +44,15 @@ def __init__(
44
44
else :
45
45
self .jwk_set_cache = None
46
46
47
+ # Replace lru_cache with TTL-aware individual key cache
48
+ # Use the same TTL as JWKSetCache for consistency
47
49
if cache_keys :
48
- # Cache signing keys
49
- # Ignore mypy (https://github.com/python/mypy/issues/2427)
50
- self .get_signing_key = lru_cache (maxsize = max_cached_keys )(
51
- self .get_signing_key
52
- ) # type: ignore
50
+ self ._key_cache_enabled = True
51
+ self ._key_cache : Dict [str , tuple [PyJWK , float ]] = {} # kid -> (key, timestamp)
52
+ self ._max_cached_keys = max_cached_keys
53
+ self ._key_cache_ttl = lifespan # Use same TTL as JWKSetCache
54
+ else :
55
+ self ._key_cache_enabled = False
53
56
54
57
def fetch_data (self ) -> Any :
55
58
jwk_set : Any = None
@@ -95,12 +98,44 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
95
98
96
99
return signing_keys
97
100
101
+ def _get_cached_key (self , kid : str ) -> Optional [PyJWK ]:
102
+ """Get a cached key if it exists and hasn't expired."""
103
+ if not self ._key_cache_enabled or kid not in self ._key_cache :
104
+ return None
105
+
106
+ key , timestamp = self ._key_cache [kid ]
107
+
108
+ # Check and remove if expired (use same logic as JWKSetCache)
109
+ if time .monotonic () - timestamp > self ._key_cache_ttl :
110
+ del self ._key_cache [kid ]
111
+ return None
112
+
113
+ return key
114
+
115
+ def _cache_key (self , kid : str , key : PyJWK ) -> None :
116
+ """Cache a key with current timestamp."""
117
+ if not self ._key_cache_enabled :
118
+ return
119
+
120
+ # Evict oldest if at capacity
121
+ if len (self ._key_cache ) >= self ._max_cached_keys and kid not in self ._key_cache :
122
+ # Simple eviction: remove oldest timestamp
123
+ oldest_kid = min (self ._key_cache .keys (),
124
+ key = lambda k : self ._key_cache [k ][1 ])
125
+ del self ._key_cache [oldest_kid ]
126
+
127
+ self ._key_cache [kid ] = (key , time .monotonic ())
128
+
98
129
def get_signing_key (self , kid : str ) -> PyJWK :
130
+ # Check TTL-aware cache first
131
+ cached_key = self ._get_cached_key (kid )
132
+ if cached_key is not None :
133
+ return cached_key
134
+
99
135
signing_keys = self .get_signing_keys ()
100
136
signing_key = self .match_kid (signing_keys , kid )
101
137
102
138
if not signing_key :
103
- # If no matching signing key from the jwk set, refresh the jwk set and try again.
104
139
signing_keys = self .get_signing_keys (refresh = True )
105
140
signing_key = self .match_kid (signing_keys , kid )
106
141
@@ -109,6 +144,8 @@ def get_signing_key(self, kid: str) -> PyJWK:
109
144
f'Unable to find a signing key that matches: "{ kid } "'
110
145
)
111
146
147
+ # Cache the key with TTL (not lru)
148
+ self ._cache_key (kid , signing_key )
112
149
return signing_key
113
150
114
151
def get_signing_key_from_jwt (self , token : str | bytes ) -> PyJWK :
0 commit comments