Skip to content

Commit b55f261

Browse files
authored
JWK .equals and .hashCode (#823)
* Adjusted JWK .equals implementations to only account for kty value and material fields (two JWKs are equal if their type and key material are equal, regardless of other public parameters and/or custom name/value pairs). * Adjusted JWK .hashCode implementation to pre-cache its value based on JwkThumpbrint fields since JWKs are immutable
1 parent f60d560 commit b55f261

24 files changed

+471
-37
lines changed

impl/src/main/java/io/jsonwebtoken/impl/lang/Bytes.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ public static byte[] concat(byte[]... arrays) {
176176
return output;
177177
}
178178

179+
/**
180+
* Clears the array by filling it with all zeros. Does nothing with a null or empty argument.
181+
*
182+
* @param bytes the (possibly null or empty) byte array to clear
183+
*/
184+
public static void clear(byte[] bytes) {
185+
if (isEmpty(bytes)) return;
186+
java.util.Arrays.fill(bytes, (byte) 0);
187+
}
188+
179189
public static boolean isEmpty(byte[] bytes) {
180190
return length(bytes) == 0;
181191
}

impl/src/main/java/io/jsonwebtoken/impl/lang/Fields.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import io.jsonwebtoken.lang.Arrays;
1919
import io.jsonwebtoken.lang.Assert;
20+
import io.jsonwebtoken.lang.Objects;
2021
import io.jsonwebtoken.lang.Registry;
2122

2223
import java.math.BigInteger;
2324
import java.net.URI;
25+
import java.security.MessageDigest;
2426
import java.security.cert.X509Certificate;
2527
import java.util.Collection;
2628
import java.util.Date;
@@ -97,4 +99,46 @@ public static Registry<String, Field<?>> registry(Registry<String, Field<?>> par
9799
newFields.put(id, field); // add new one
98100
return registry(newFields.values());
99101
}
102+
103+
private static byte[] bytes(BigInteger i) {
104+
return i != null ? i.toByteArray() : null;
105+
}
106+
107+
public static boolean bytesEquals(BigInteger a, BigInteger b) {
108+
//noinspection NumberEquality
109+
if (a == b) return true;
110+
if (a == null || b == null) return false;
111+
byte[] aBytes = bytes(a);
112+
byte[] bBytes = bytes(b);
113+
try {
114+
return MessageDigest.isEqual(aBytes, bBytes);
115+
} finally {
116+
Bytes.clear(aBytes);
117+
Bytes.clear(bBytes);
118+
}
119+
}
120+
121+
private static <T> boolean equals(T a, T b, Field<T> field) {
122+
if (a == b) return true;
123+
if (a == null || b == null) return false;
124+
if (field.isSecret()) {
125+
// byte[] and BigInteger are the only types of secret Fields in the JJWT codebase
126+
// (i.e. Field.isSecret() == true). If a Field is ever marked as secret, and it's not one of these two
127+
// data types, we need to know about it. So we use the 'assertSecret' helper above to ensure we do:
128+
if (a instanceof byte[]) {
129+
return b instanceof byte[] && MessageDigest.isEqual((byte[]) a, (byte[]) b);
130+
} else if (a instanceof BigInteger) {
131+
return b instanceof BigInteger && bytesEquals((BigInteger) a, (BigInteger) b);
132+
}
133+
}
134+
// default to a standard null-safe comparison:
135+
return Objects.nullSafeEquals(a, b);
136+
}
137+
138+
public static <T> boolean equals(FieldReadable a, Object o, Field<T> field) {
139+
if (a == o) return true;
140+
if (a == null || !(o instanceof FieldReadable)) return false;
141+
FieldReadable b = (FieldReadable) o;
142+
return equals(a.get(field), b.get(field), field);
143+
}
100144
}

impl/src/main/java/io/jsonwebtoken/impl/security/AbstractJwk.java

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
import java.nio.charset.StandardCharsets;
3535
import java.security.Key;
36+
import java.security.PrivateKey;
37+
import java.security.PublicKey;
38+
import java.util.ArrayList;
3639
import java.util.Collection;
3740
import java.util.Iterator;
3841
import java.util.List;
@@ -48,10 +51,11 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
4851
.set().setId("key_ops").setName("Key Operations").build();
4952
static final Field<String> KTY = Fields.string("kty", "Key Type");
5053
static final Set<Field<?>> FIELDS = Collections.setOf(ALG, KID, KEY_OPS, KTY);
51-
5254
public static final String IMMUTABLE_MSG = "JWKs are immutable and may not be modified.";
55+
5356
protected final JwkContext<K> context;
5457
private final List<Field<?>> THUMBPRINT_FIELDS;
58+
private final int hashCode;
5559

5660
/**
5761
* @param ctx the backing JwkContext containing the JWK field values.
@@ -71,6 +75,40 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
7175
String kid = thumbprint.toString();
7276
ctx.setId(kid);
7377
}
78+
this.hashCode = computeHashCode();
79+
}
80+
81+
/**
82+
* Compute and return the JWK hashCode. As JWKs are immutable, this value will be cached as a final constant
83+
* upon JWK instantiation. This uses the JWK's thumbprint fields during computation, but differs from JwkThumbprint
84+
* calculation in two ways:
85+
* <ol>
86+
* <li>JwkThumbprints use a MessageDigest calculation, which is unnecessary overhead for a hashcode</li>
87+
* <li>The hashCode calculation uses each field's idiomatic (Java) object value instead of the
88+
* JwkThumbprint-required canonical (String) value.</li>
89+
* </ol>
90+
*
91+
* @return the JWK hashcode
92+
*/
93+
private int computeHashCode() {
94+
List<Object> list = new ArrayList<>(this.THUMBPRINT_FIELDS.size() + 1 /* possible discriminator */);
95+
// So we don't leak information about the private key value, we need a discriminator to ensure that
96+
// public and private key hashCodes are not identical (in case both JWKs need to be in the same hash set).
97+
// So we add a discriminator String to the list of values that are used during hashCode calculation
98+
Key key = Assert.notNull(toKey(), "JWK toKey() value cannot be null.");
99+
if (key instanceof PublicKey) {
100+
list.add("Public");
101+
} else if (key instanceof PrivateKey) {
102+
list.add("Private");
103+
}
104+
for (Field<?> field : this.THUMBPRINT_FIELDS) {
105+
// Unlike thumbprint calculation, we get the idiomatic (Java) value, not canonical (String) value
106+
// (We could have used either actually, but the idiomatic value hashCode calculation is probably
107+
// faster).
108+
Object val = Assert.notNull(get(field), "computeHashCode: Field idiomatic value cannot be null.");
109+
list.add(val);
110+
}
111+
return Objects.nullSafeHashCode(list.toArray());
74112
}
75113

76114
private String getRequiredThumbprintValue(Field<?> field) {
@@ -230,13 +268,20 @@ public String toString() {
230268
}
231269

232270
@Override
233-
public int hashCode() {
234-
return this.context.hashCode();
271+
public final int hashCode() {
272+
return this.hashCode;
235273
}
236274

237-
@SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
238275
@Override
239-
public boolean equals(Object obj) {
240-
return this.context.equals(obj);
276+
public final boolean equals(Object obj) {
277+
if (obj == this) return true;
278+
if (obj instanceof Jwk<?>) {
279+
Jwk<?> other = (Jwk<?>) obj;
280+
// this.getType() guaranteed non-null in constructor:
281+
return getType().equals(other.getType()) && equals(other);
282+
}
283+
return false;
241284
}
285+
286+
protected abstract boolean equals(Jwk<?> jwk);
242287
}

impl/src/main/java/io/jsonwebtoken/impl/security/AbstractJwkBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ static class DefaultSecretJwkBuilder extends AbstractJwkBuilder<SecretKey, Secre
170170
implements SecretJwkBuilder {
171171
public DefaultSecretJwkBuilder(JwkContext<SecretKey> ctx) {
172172
super(ctx);
173+
// assign a standard algorithm if possible:
174+
Key key = Assert.notNull(ctx.getKey(), "SecretKey cannot be null.");
175+
DefaultMacAlgorithm mac = DefaultMacAlgorithm.findByKey(key);
176+
if (mac != null) {
177+
algorithm(mac.getId());
178+
}
173179
}
174180
}
175181
}

impl/src/main/java/io/jsonwebtoken/impl/security/AbstractPrivateJwk.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io.jsonwebtoken.impl.lang.Field;
1919
import io.jsonwebtoken.lang.Assert;
20+
import io.jsonwebtoken.security.Jwk;
2021
import io.jsonwebtoken.security.KeyPair;
2122
import io.jsonwebtoken.security.PrivateJwk;
2223
import io.jsonwebtoken.security.PublicJwk;
@@ -47,4 +48,11 @@ public M toPublicJwk() {
4748
public KeyPair<L, K> toKeyPair() {
4849
return this.keyPair;
4950
}
51+
52+
@Override
53+
protected final boolean equals(Jwk<?> jwk) {
54+
return jwk instanceof PrivateJwk && equals((PrivateJwk<?, ?, ?>) jwk);
55+
}
56+
57+
protected abstract boolean equals(PrivateJwk<?, ?, ?> jwk);
5058
}

impl/src/main/java/io/jsonwebtoken/impl/security/AbstractPublicJwk.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package io.jsonwebtoken.impl.security;
1717

1818
import io.jsonwebtoken.impl.lang.Field;
19+
import io.jsonwebtoken.security.Jwk;
1920
import io.jsonwebtoken.security.PublicJwk;
2021

2122
import java.security.PublicKey;
@@ -25,4 +26,11 @@ abstract class AbstractPublicJwk<K extends PublicKey> extends AbstractAsymmetric
2526
AbstractPublicJwk(JwkContext<K> ctx, List<Field<?>> thumbprintFields) {
2627
super(ctx, thumbprintFields);
2728
}
29+
30+
@Override
31+
protected final boolean equals(Jwk<?> jwk) {
32+
return jwk instanceof PublicJwk && equals((PublicJwk<?>) jwk);
33+
}
34+
35+
protected abstract boolean equals(PublicJwk<?> jwk);
2836
}

impl/src/main/java/io/jsonwebtoken/impl/security/DefaultEcPrivateJwk.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
import io.jsonwebtoken.lang.Collections;
2121
import io.jsonwebtoken.security.EcPrivateJwk;
2222
import io.jsonwebtoken.security.EcPublicJwk;
23+
import io.jsonwebtoken.security.PrivateJwk;
2324

2425
import java.math.BigInteger;
2526
import java.security.interfaces.ECPrivateKey;
2627
import java.security.interfaces.ECPublicKey;
2728
import java.util.Set;
2829

30+
import static io.jsonwebtoken.impl.security.DefaultEcPublicJwk.equalsPublic;
31+
2932
class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey, EcPublicJwk> implements EcPrivateJwk {
3033

3134
static final Field<BigInteger> D = Fields.secretBigInt("d", "ECC Private Key");
@@ -38,4 +41,9 @@ class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey,
3841
DefaultEcPublicJwk.THUMBPRINT_FIELDS,
3942
pubJwk);
4043
}
44+
45+
@Override
46+
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
47+
return jwk instanceof EcPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
48+
}
4149
}

impl/src/main/java/io/jsonwebtoken/impl/security/DefaultEcPublicJwk.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
package io.jsonwebtoken.impl.security;
1717

1818
import io.jsonwebtoken.impl.lang.Field;
19+
import io.jsonwebtoken.impl.lang.FieldReadable;
1920
import io.jsonwebtoken.impl.lang.Fields;
2021
import io.jsonwebtoken.lang.Collections;
2122
import io.jsonwebtoken.security.EcPublicJwk;
23+
import io.jsonwebtoken.security.PublicJwk;
2224

2325
import java.math.BigInteger;
2426
import java.security.interfaces.ECPublicKey;
@@ -39,4 +41,15 @@ class DefaultEcPublicJwk extends AbstractPublicJwk<ECPublicKey> implements EcPub
3941
DefaultEcPublicJwk(JwkContext<ECPublicKey> ctx) {
4042
super(ctx, THUMBPRINT_FIELDS);
4143
}
44+
45+
static boolean equalsPublic(FieldReadable self, Object candidate) {
46+
return Fields.equals(self, candidate, CRV) &&
47+
Fields.equals(self, candidate, X) &&
48+
Fields.equals(self, candidate, Y);
49+
}
50+
51+
@Override
52+
protected boolean equals(PublicJwk<?> jwk) {
53+
return jwk instanceof EcPublicJwk && equalsPublic(this, jwk);
54+
}
4255
}

impl/src/main/java/io/jsonwebtoken/impl/security/DefaultMacAlgorithm.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ final class DefaultMacAlgorithm extends AbstractSecureDigestAlgorithm<SecretKey,
5353
static final DefaultMacAlgorithm HS384 = new DefaultMacAlgorithm(384);
5454
static final DefaultMacAlgorithm HS512 = new DefaultMacAlgorithm(512);
5555

56-
private static final Map<String, MacAlgorithm> JCA_NAME_MAP;
56+
private static final Map<String, DefaultMacAlgorithm> JCA_NAME_MAP;
5757

5858
static {
5959
JCA_NAME_MAP = new LinkedHashMap<>(6);
@@ -96,15 +96,15 @@ private static boolean isJwaStandardJcaName(String jcaName) {
9696
return JCA_NAME_MAP.containsKey(key);
9797
}
9898

99-
static MacAlgorithm findByKey(Key key) {
99+
static DefaultMacAlgorithm findByKey(Key key) {
100100

101101
String alg = KeysBridge.findAlgorithm(key);
102102
if (!Strings.hasText(alg)) {
103103
return null;
104104
}
105105

106106
String upper = alg.toUpperCase(Locale.ENGLISH);
107-
MacAlgorithm mac = JCA_NAME_MAP.get(upper);
107+
DefaultMacAlgorithm mac = JCA_NAME_MAP.get(upper);
108108
if (mac == null) {
109109
return null;
110110
}

impl/src/main/java/io/jsonwebtoken/impl/security/DefaultOctetPrivateJwk.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
import io.jsonwebtoken.lang.Collections;
2121
import io.jsonwebtoken.security.OctetPrivateJwk;
2222
import io.jsonwebtoken.security.OctetPublicJwk;
23+
import io.jsonwebtoken.security.PrivateJwk;
2324

2425
import java.security.PrivateKey;
2526
import java.security.PublicKey;
2627
import java.util.Set;
2728

28-
public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> {
29+
import static io.jsonwebtoken.impl.security.DefaultOctetPublicJwk.equalsPublic;
30+
31+
public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey>
32+
extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> {
2933

3034
static final Field<byte[]> D = Fields.bytes("d", "The private key").setSecret(true).build();
3135

@@ -37,4 +41,9 @@ public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> e
3741
// https://www.rfc-editor.org/rfc/rfc7638#section-3.2.1
3842
DefaultOctetPublicJwk.THUMBPRINT_FIELDS, pubJwk);
3943
}
44+
45+
@Override
46+
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
47+
return jwk instanceof OctetPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
48+
}
4049
}

impl/src/main/java/io/jsonwebtoken/impl/security/DefaultOctetPublicJwk.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
package io.jsonwebtoken.impl.security;
1717

1818
import io.jsonwebtoken.impl.lang.Field;
19+
import io.jsonwebtoken.impl.lang.FieldReadable;
1920
import io.jsonwebtoken.impl.lang.Fields;
2021
import io.jsonwebtoken.lang.Collections;
2122
import io.jsonwebtoken.security.OctetPublicJwk;
23+
import io.jsonwebtoken.security.PublicJwk;
2224

2325
import java.security.PublicKey;
2426
import java.util.List;
@@ -37,4 +39,13 @@ public class DefaultOctetPublicJwk<T extends PublicKey> extends AbstractPublicJw
3739
DefaultOctetPublicJwk(JwkContext<T> ctx) {
3840
super(ctx, THUMBPRINT_FIELDS);
3941
}
42+
43+
static boolean equalsPublic(FieldReadable self, Object candidate) {
44+
return Fields.equals(self, candidate, CRV) && Fields.equals(self, candidate, X);
45+
}
46+
47+
@Override
48+
protected boolean equals(PublicJwk<?> jwk) {
49+
return jwk instanceof OctetPublicJwk && equalsPublic(this, jwk);
50+
}
4051
}

0 commit comments

Comments
 (0)