Skip to content

Commit dc63e0d

Browse files
authored
Merge pull request #1593 from ksylvan/0707-claude-oauth-improvement
Refactor: Generalize OAuth flow for improved token handling.
2 parents 47cf24e + 75842d8 commit dc63e0d

File tree

3 files changed

+474
-19
lines changed

3 files changed

+474
-19
lines changed

plugins/ai/anthropic/anthropic.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ const webSearchToolName = "web_search"
1919
const webSearchToolType = "web_search_20250305"
2020
const sourcesHeader = "## Sources"
2121

22-
const vendorTokenIdentifier = "claude"
22+
const authTokenIdentifier = "claude"
2323

2424
func NewClient() (ret *Client) {
2525
vendorName := "Anthropic"
@@ -65,15 +65,15 @@ func (an *Client) IsConfigured() bool {
6565
}
6666

6767
// If no valid token exists, automatically run OAuth flow
68-
if !storage.HasValidToken(vendorTokenIdentifier, 5) {
68+
if !storage.HasValidToken(authTokenIdentifier, 5) {
6969
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
70-
_, err := RunOAuthFlow()
70+
_, err := RunOAuthFlow(authTokenIdentifier)
7171
if err != nil {
7272
fmt.Printf("OAuth authentication failed: %v\n", err)
7373
return false
7474
}
7575
// After successful OAuth flow, check again
76-
return storage.HasValidToken("claude", 5)
76+
return storage.HasValidToken(authTokenIdentifier, 5)
7777
}
7878

7979
return true
@@ -107,9 +107,9 @@ func (an *Client) Setup() (err error) {
107107
return err
108108
}
109109

110-
if !storage.HasValidToken("claude", 5) {
110+
if !storage.HasValidToken(authTokenIdentifier, 5) {
111111
// No valid token, run OAuth flow
112-
if _, err = RunOAuthFlow(); err != nil {
112+
if _, err = RunOAuthFlow(authTokenIdentifier); err != nil {
113113
return err
114114
}
115115
}

plugins/ai/anthropic/oauth.go

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
3737
newReq := req.Clone(req.Context())
3838

3939
// Get current token (may refresh if needed)
40-
token, err := t.getValidToken()
40+
token, err := t.getValidToken(authTokenIdentifier)
4141
if err != nil {
4242
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
4343
}
@@ -58,21 +58,21 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
5858
}
5959

6060
// getValidToken returns a valid access token, refreshing if necessary
61-
func (t *OAuthTransport) getValidToken() (string, error) {
61+
func (t *OAuthTransport) getValidToken(tokenIdentifier string) (string, error) {
6262
storage, err := common.NewOAuthStorage()
6363
if err != nil {
6464
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
6565
}
6666

6767
// Load stored token
68-
token, err := storage.LoadToken("claude")
68+
token, err := storage.LoadToken(tokenIdentifier)
6969
if err != nil {
7070
return "", fmt.Errorf("failed to load stored token: %w", err)
7171
}
7272
// If no token exists, run OAuth flow
7373
if token == nil {
7474
fmt.Println("No OAuth token found, initiating authentication...")
75-
newAccessToken, err := RunOAuthFlow()
75+
newAccessToken, err := RunOAuthFlow(tokenIdentifier)
7676
if err != nil {
7777
return "", fmt.Errorf("failed to authenticate: %w", err)
7878
}
@@ -82,11 +82,11 @@ func (t *OAuthTransport) getValidToken() (string, error) {
8282
// Check if token needs refresh (5 minute buffer)
8383
if token.IsExpired(5) {
8484
fmt.Println("OAuth token expired, refreshing...")
85-
newAccessToken, err := RefreshToken()
85+
newAccessToken, err := RefreshToken(tokenIdentifier)
8686
if err != nil {
8787
// If refresh fails, try re-authentication
8888
fmt.Println("Token refresh failed, re-authenticating...")
89-
newAccessToken, err = RunOAuthFlow()
89+
newAccessToken, err = RunOAuthFlow(tokenIdentifier)
9090
if err != nil {
9191
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
9292
}
@@ -129,7 +129,28 @@ func openBrowser(url string) {
129129
}
130130

131131
// RunOAuthFlow executes the complete OAuth authorization flow
132-
func RunOAuthFlow() (token string, err error) {
132+
func RunOAuthFlow(tokenIdentifier string) (token string, err error) {
133+
// First check if we have an existing token that can be refreshed
134+
storage, err := common.NewOAuthStorage()
135+
if err == nil {
136+
existingToken, err := storage.LoadToken(tokenIdentifier)
137+
if err == nil && existingToken != nil {
138+
// If token exists but is expired, try refreshing first
139+
if existingToken.IsExpired(5) {
140+
fmt.Println("Found expired OAuth token, attempting refresh...")
141+
refreshedToken, refreshErr := RefreshToken(tokenIdentifier)
142+
if refreshErr == nil {
143+
fmt.Println("Token refresh successful")
144+
return refreshedToken, nil
145+
}
146+
fmt.Printf("Token refresh failed (%v), proceeding with full OAuth flow...\n", refreshErr)
147+
} else {
148+
// Token exists and is still valid
149+
return existingToken.AccessToken, nil
150+
}
151+
}
152+
}
153+
133154
verifier, challenge, err := generatePKCE()
134155
if err != nil {
135156
return
@@ -171,12 +192,12 @@ func RunOAuthFlow() (token string, err error) {
171192
"code_verifier": verifier,
172193
}
173194

174-
token, err = exchangeToken(tokenReq)
195+
token, err = exchangeToken(tokenIdentifier, tokenReq)
175196
return
176197
}
177198

178199
// exchangeToken exchanges authorization code for access token
179-
func exchangeToken(params map[string]string) (token string, err error) {
200+
func exchangeToken(tokenIdentifier string, params map[string]string) (token string, err error) {
180201
reqBody, err := json.Marshal(params)
181202
if err != nil {
182203
return
@@ -219,7 +240,7 @@ func exchangeToken(params map[string]string) (token string, err error) {
219240
Scope: result.Scope,
220241
}
221242

222-
if err = storage.SaveToken("claude", oauthToken); err != nil {
243+
if err = storage.SaveToken(tokenIdentifier, oauthToken); err != nil {
223244
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
224245
}
225246

@@ -228,14 +249,14 @@ func exchangeToken(params map[string]string) (token string, err error) {
228249
}
229250

230251
// RefreshToken refreshes an expired OAuth token using the refresh token
231-
func RefreshToken() (string, error) {
252+
func RefreshToken(tokenIdentifier string) (string, error) {
232253
storage, err := common.NewOAuthStorage()
233254
if err != nil {
234255
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
235256
}
236257

237258
// Load existing token
238-
token, err := storage.LoadToken("claude")
259+
token, err := storage.LoadToken(tokenIdentifier)
239260
if err != nil {
240261
return "", fmt.Errorf("failed to load stored token: %w", err)
241262
}
@@ -292,7 +313,7 @@ func RefreshToken() (string, error) {
292313
newToken.RefreshToken = token.RefreshToken
293314
}
294315

295-
if err = storage.SaveToken("claude", newToken); err != nil {
316+
if err = storage.SaveToken(tokenIdentifier, newToken); err != nil {
296317
return "", fmt.Errorf("failed to save refreshed token: %w", err)
297318
}
298319

0 commit comments

Comments
 (0)