Skip to content

Added support for type parameters in the ParseXXX functions #271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand All @@ -103,11 +103,11 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand Down Expand Up @@ -136,11 +136,11 @@ func (m MyCustomClaims) CustomValidation() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand All @@ -154,7 +154,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/golang-jwt/jwt/v5

go 1.16
go 1.18
71 changes: 62 additions & 9 deletions map_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ func TestVerifyAud(t *testing.T) {

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var opts []ParserOption
var opts []ParserOption[MapClaims]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: ParserOption[MapClaims] (ParserOption is not a generic type)


if test.Required {
opts = append(opts, WithAudience(test.Comparison))
opts = append(opts, WithAudience[MapClaims](test.Comparison))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: cannot index WithAudience (value of type func(aud string) ParserOption)

}

validator := newValidator(opts...)
validator := newValidator[MapClaims](opts...)
got := validator.Validate(test.MapClaims)

if (got == nil) != test.Expected {
Expand All @@ -77,7 +77,7 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
"iat": "foo",
}
want := false
got := newValidator(WithIssuedAt()).Validate(mapClaims)
got := newValidator[MapClaims](WithIssuedAt[MapClaims]()).Validate(mapClaims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: cannot index WithIssuedAt (value of type func() ParserOption)

if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -88,7 +88,7 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
"nbf": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := newValidator[MapClaims]().Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -99,7 +99,7 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
"exp": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := newValidator[MapClaims]().Validate(mapClaims)

if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
Expand All @@ -112,25 +112,78 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
"exp": float64(exp.Unix()),
}
want := false
got := newValidator(WithTimeFunc(func() time.Time {
got := newValidator[MapClaims](WithTimeFunc[MapClaims](func() time.Time {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: cannot index WithTimeFunc (value of type func(f func() time.Time) ParserOption)

return exp
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

got = newValidator(WithTimeFunc(func() time.Time {
got = newValidator[MapClaims](WithTimeFunc[MapClaims](func() time.Time {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: cannot index WithTimeFunc (value of type func(f func() time.Time) ParserOption)

return exp.Add(1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

want = true
got = newValidator(WithTimeFunc(func() time.Time {
got = newValidator[MapClaims](WithTimeFunc[MapClaims](func() time.Time {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚫 [staticcheck] <compile> reported by reviewdog 🐶
invalid operation: cannot index WithTimeFunc (value of type func(f func() time.Time) ParserOption)

return exp.Add(-1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}

func TestMapClaims_ParseString(t *testing.T) {
type args struct {
key string
}
tests := []struct {
name string
m MapClaims
args args
want string
wantErr bool
}{
{
name: "missing key",
m: MapClaims{},
args: args{
key: "mykey",
},
want: "",
wantErr: false,
},
{
name: "wrong key type",
m: MapClaims{"mykey": 4},
args: args{
key: "mykey",
},
want: "",
wantErr: true,
},
{
name: "correct key type",
m: MapClaims{"mykey": "mystring"},
args: args{
key: "mykey",
},
want: "mystring",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.ParseString(tt.args.key)
if (err != nil) != tt.wantErr {
t.Errorf("MapClaims.ParseString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("MapClaims.ParseString() = %v, want %v", got, tt.want)
}
})
}
}
28 changes: 14 additions & 14 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

type Parser struct {
type Parser[T Claims] struct {
// If populated, only these methods will be considered valid.
validMethods []string

Expand All @@ -21,24 +21,24 @@ type Parser struct {
}

// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
func NewParser[T Claims](options ...ParserOption) *Parser[T] {
p := &Parser[T]{
validator: &validator{},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(p)
option((*Parser[Claims])(p))
}

return p
}

// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
/*func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
}*/

// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
Expand All @@ -47,7 +47,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
func (p *Parser[T]) ParseWithClaims(tokenString string, claims T, keyFunc Keyfunc[T]) (*Token[T], error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
Expand Down Expand Up @@ -89,7 +89,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
if !p.skipClaimsValidation {
// Make sure we have at least a default validator
if p.validator == nil {
p.validator = newValidator()
p.validator = newValidator[T]()
}

if err := p.validator.Validate(claims); err != nil {
Expand Down Expand Up @@ -124,13 +124,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string, claims T) (token *Token[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}

token = &Token{Raw: tokenString}
token = &Token[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand All @@ -156,11 +156,11 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
/*if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
} else {*/
err = dec.Decode(&claims)
//}
// Handle decode error
if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
Expand Down
20 changes: 10 additions & 10 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,34 @@ import "time"
// ParserOption is used to implement functional-style options that modify the behavior of the parser. To add
// new options, just create a function (ideally beginning with With or Without) that returns an anonymous function that
// takes a *Parser type as input and manipulates its configuration accordingly.
type ParserOption func(*Parser)
type ParserOption func(*Parser[Claims])

// WithValidMethods is an option to supply algorithm methods that the parser will check. Only those methods will be considered valid.
// It is heavily encouraged to use this option in order to prevent attacks such as https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/.
func WithValidMethods(methods []string) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validMethods = methods
}
}

// WithJSONNumber is an option to configure the underlying JSON parser with UseNumber
func WithJSONNumber() ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.useJSONNumber = true
}
}

// WithoutClaimsValidation is an option to disable claims validation. This option should only be used if you exactly know
// what you are doing.
func WithoutClaimsValidation() ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.skipClaimsValidation = true
}
}

// WithLeeway returns the ParserOption for specifying the leeway window.
func WithLeeway(leeway time.Duration) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.leeway = leeway
}
}
Expand All @@ -41,15 +41,15 @@ func WithLeeway(leeway time.Duration) ParserOption {
// primary use-case for this is testing. If you are looking for a way to account
// for clock-skew, WithLeeway should be used instead.
func WithTimeFunc(f func() time.Time) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.timeFunc = f
}
}

// WithIssuedAt returns the ParserOption to enable verification
// of issued-at.
func WithIssuedAt() ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.verifyIat = true
}
}
Expand All @@ -62,7 +62,7 @@ func WithIssuedAt() ParserOption {
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim.
func WithAudience(aud string) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.expectedAud = aud
}
}
Expand All @@ -75,7 +75,7 @@ func WithAudience(aud string) ParserOption {
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim.
func WithIssuer(iss string) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.expectedIss = iss
}
}
Expand All @@ -88,7 +88,7 @@ func WithIssuer(iss string) ParserOption {
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim.
func WithSubject(sub string) ParserOption {
return func(p *Parser) {
return func(p *Parser[Claims]) {
p.validator.expectedSub = sub
}
}
Loading