Skip to content

Support concurrent CSRF cookies by using a prefix of nonce #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 23, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move ValidateState out and make CSRF cookies last 1h
Michal Witkowski committed Sep 3, 2020
commit 3c7bab9ab7501fe8383b95b7aa033565706696ad
29 changes: 15 additions & 14 deletions internal/auth.go
Original file line number Diff line number Diff line change
@@ -175,6 +175,10 @@ func buildCSRFCookieName(nonce string) string {
}

// MakeCSRFCookie makes a csrf cookie (used during login only)
//
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
// That's because some CSRF cookies may belong to auth flows that don't complete
// and thus may not get cleared by ClearCookie.
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: buildCSRFCookieName(nonce),
@@ -183,7 +187,7 @@ func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
Expires: time.Now().Local().Add(time.Hour * 1),
}
}

@@ -201,12 +205,7 @@ func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
}

// FindCSRFCookie extracts the CSRF cookie from the request based on state.
func FindCSRFCookie(r *http.Request) (c *http.Cookie, err error) {
state := r.URL.Query().Get("state")
if len(state) < 34 {
return nil, errors.New("Invalid CSRF state value")
}

func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
// Check for CSRF cookie
c, err = r.Cookie(buildCSRFCookieName(state))
if err != nil {
@@ -216,17 +215,11 @@ func FindCSRFCookie(r *http.Request) (c *http.Cookie, err error) {
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state")

func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
}

if len(state) < 34 {
return false, "", "", errors.New("Invalid CSRF state value")
}

// Check nonce match
if c.Value != state[:32] {
return false, "", "", errors.New("CSRF cookie does not match state")
@@ -248,6 +241,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
}

// ValidateState checks whether the state is of right length.
func ValidateState(state string) error {
if len(state) < 34 {
return errors.New("Invalid CSRF state value")
}
return nil
}

// Nonce generates a random nonce
func Nonce() (error, string) {
nonce := make([]byte, 16)
46 changes: 23 additions & 23 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tfa

import (
"fmt"
"net/http"
"net/url"
"strings"
@@ -251,56 +250,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
c := &http.Cookie{}

newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
state := ""

// Should require 32 char string
r := newCsrfRequest("")
state = ""
c.Value = ""
valid, _, _, err := ValidateCSRFCookie(r, c)
valid, _, _, err := ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}
c.Value = "123456789012345678901234567890123"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}

// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}

// Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99")
state = "12345678901234567890123456789012:99"
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error())
}

// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
state = "12345678901234567890123456789012:p99:url123"
c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error")
assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect")
}

func TestValidateState(t *testing.T) {
assert := assert.New(t)

// Should require valid state
state := "12345678901234567890123456789012:"
err := ValidateState(state)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should pass this state
state = "12345678901234567890123456789012:p99:url123"
err = ValidateState(state)
assert.Nil(err, "valid request should not return an error")
}

func TestMakeState(t *testing.T) {
assert := assert.New(t)

12 changes: 10 additions & 2 deletions internal/server.go
Original file line number Diff line number Diff line change
@@ -120,17 +120,25 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback")
state := r.URL.Query().Get("state")
if err := ValidateState(state); err != nil {
logger.WithFields(logrus.Fields{
"error": err,
}).Warn("Bad CSRF state")
http.Error(w, "Not authorized", 401)
return
}

// Check for CSRF cookie
c, err := FindCSRFCookie(r)
c, err := FindCSRFCookie(r, state)
if err != nil {
logger.Info("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}

// Validate state
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
if !valid {
logger.WithFields(logrus.Fields{
"error": err,