Skip to content

Commit dde4579

Browse files
committed
Merge pull request #49 from defyrlt/feature/custom_json_encoder
Custom JSON Encoder
2 parents 139a779 + 842f0d1 commit dde4579

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

jwt/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def header(jwt):
215215
raise DecodeError('Invalid header encoding')
216216

217217

218-
def encode(payload, key, algorithm='HS256', headers=None):
218+
def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
219219
segments = []
220220

221221
if algorithm is None:
@@ -231,7 +231,9 @@ def encode(payload, key, algorithm='HS256', headers=None):
231231
if headers:
232232
header.update(headers)
233233

234-
json_header = json.dumps(header, separators=(',', ':')).encode('utf-8')
234+
json_header = json.dumps(header,
235+
separators=(',', ':'),
236+
cls=json_encoder).encode('utf-8')
235237
segments.append(base64url_encode(json_header))
236238

237239
# Payload
@@ -240,7 +242,9 @@ def encode(payload, key, algorithm='HS256', headers=None):
240242
if isinstance(payload.get(time_claim), datetime):
241243
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
242244

243-
json_payload = json.dumps(payload, separators=(',', ':')).encode('utf-8')
245+
json_payload = json.dumps(payload,
246+
separators=(',', ':'),
247+
cls=json_encoder).encode('utf-8')
244248
segments.append(base64url_encode(json_payload))
245249

246250
# Segments

tests/test_jwt.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from __future__ import unicode_literals
2+
23
from calendar import timegm
34
from datetime import datetime
5+
from decimal import Decimal
6+
47
import sys
58
import time
69
import unittest
10+
import json
711

812
import jwt
913

@@ -811,6 +815,27 @@ def test_raise_exception_token_without_issuer(self):
811815
jwt.InvalidIssuer,
812816
lambda: jwt.decode(token, 'secret', issuer=issuer))
813817

818+
def test_custom_json_encoder(self):
819+
820+
class CustomJSONEncoder(json.JSONEncoder):
821+
822+
def default(self, o):
823+
if isinstance(o, Decimal):
824+
return 'it worked'
825+
return super(CustomJSONEncoder, self).default(o)
826+
827+
data = {
828+
'some_decimal': Decimal('2.2')
829+
}
830+
831+
self.assertRaises(
832+
TypeError,
833+
lambda: jwt.encode(data, 'secret'))
834+
835+
token = jwt.encode(data, 'secret', json_encoder=CustomJSONEncoder)
836+
payload = jwt.decode(token, 'secret')
837+
self.assertEqual(payload, {'some_decimal': 'it worked'})
838+
814839

815840
if __name__ == '__main__':
816841
unittest.main()

0 commit comments

Comments
 (0)