@@ -37,7 +37,7 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
37
37
newReq := req .Clone (req .Context ())
38
38
39
39
// Get current token (may refresh if needed)
40
- token , err := t .getValidToken ()
40
+ token , err := t .getValidToken (authTokenIdentifier )
41
41
if err != nil {
42
42
return nil , fmt .Errorf ("failed to get valid OAuth token: %w" , err )
43
43
}
@@ -58,21 +58,21 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
58
58
}
59
59
60
60
// getValidToken returns a valid access token, refreshing if necessary
61
- func (t * OAuthTransport ) getValidToken () (string , error ) {
61
+ func (t * OAuthTransport ) getValidToken (tokenIdentifier string ) (string , error ) {
62
62
storage , err := common .NewOAuthStorage ()
63
63
if err != nil {
64
64
return "" , fmt .Errorf ("failed to create OAuth storage: %w" , err )
65
65
}
66
66
67
67
// Load stored token
68
- token , err := storage .LoadToken ("claude" )
68
+ token , err := storage .LoadToken (tokenIdentifier )
69
69
if err != nil {
70
70
return "" , fmt .Errorf ("failed to load stored token: %w" , err )
71
71
}
72
72
// If no token exists, run OAuth flow
73
73
if token == nil {
74
74
fmt .Println ("No OAuth token found, initiating authentication..." )
75
- newAccessToken , err := RunOAuthFlow ()
75
+ newAccessToken , err := RunOAuthFlow (tokenIdentifier )
76
76
if err != nil {
77
77
return "" , fmt .Errorf ("failed to authenticate: %w" , err )
78
78
}
@@ -82,11 +82,11 @@ func (t *OAuthTransport) getValidToken() (string, error) {
82
82
// Check if token needs refresh (5 minute buffer)
83
83
if token .IsExpired (5 ) {
84
84
fmt .Println ("OAuth token expired, refreshing..." )
85
- newAccessToken , err := RefreshToken ()
85
+ newAccessToken , err := RefreshToken (tokenIdentifier )
86
86
if err != nil {
87
87
// If refresh fails, try re-authentication
88
88
fmt .Println ("Token refresh failed, re-authenticating..." )
89
- newAccessToken , err = RunOAuthFlow ()
89
+ newAccessToken , err = RunOAuthFlow (tokenIdentifier )
90
90
if err != nil {
91
91
return "" , fmt .Errorf ("failed to refresh or re-authenticate: %w" , err )
92
92
}
@@ -129,7 +129,28 @@ func openBrowser(url string) {
129
129
}
130
130
131
131
// RunOAuthFlow executes the complete OAuth authorization flow
132
- func RunOAuthFlow () (token string , err error ) {
132
+ func RunOAuthFlow (tokenIdentifier string ) (token string , err error ) {
133
+ // First check if we have an existing token that can be refreshed
134
+ storage , err := common .NewOAuthStorage ()
135
+ if err == nil {
136
+ existingToken , err := storage .LoadToken (tokenIdentifier )
137
+ if err == nil && existingToken != nil {
138
+ // If token exists but is expired, try refreshing first
139
+ if existingToken .IsExpired (5 ) {
140
+ fmt .Println ("Found expired OAuth token, attempting refresh..." )
141
+ refreshedToken , refreshErr := RefreshToken (tokenIdentifier )
142
+ if refreshErr == nil {
143
+ fmt .Println ("Token refresh successful" )
144
+ return refreshedToken , nil
145
+ }
146
+ fmt .Printf ("Token refresh failed (%v), proceeding with full OAuth flow...\n " , refreshErr )
147
+ } else {
148
+ // Token exists and is still valid
149
+ return existingToken .AccessToken , nil
150
+ }
151
+ }
152
+ }
153
+
133
154
verifier , challenge , err := generatePKCE ()
134
155
if err != nil {
135
156
return
@@ -171,12 +192,12 @@ func RunOAuthFlow() (token string, err error) {
171
192
"code_verifier" : verifier ,
172
193
}
173
194
174
- token , err = exchangeToken (tokenReq )
195
+ token , err = exchangeToken (tokenIdentifier , tokenReq )
175
196
return
176
197
}
177
198
178
199
// exchangeToken exchanges authorization code for access token
179
- func exchangeToken (params map [string ]string ) (token string , err error ) {
200
+ func exchangeToken (tokenIdentifier string , params map [string ]string ) (token string , err error ) {
180
201
reqBody , err := json .Marshal (params )
181
202
if err != nil {
182
203
return
@@ -219,7 +240,7 @@ func exchangeToken(params map[string]string) (token string, err error) {
219
240
Scope : result .Scope ,
220
241
}
221
242
222
- if err = storage .SaveToken ("claude" , oauthToken ); err != nil {
243
+ if err = storage .SaveToken (tokenIdentifier , oauthToken ); err != nil {
223
244
return result .AccessToken , fmt .Errorf ("failed to save OAuth token: %w" , err )
224
245
}
225
246
@@ -228,14 +249,14 @@ func exchangeToken(params map[string]string) (token string, err error) {
228
249
}
229
250
230
251
// RefreshToken refreshes an expired OAuth token using the refresh token
231
- func RefreshToken () (string , error ) {
252
+ func RefreshToken (tokenIdentifier string ) (string , error ) {
232
253
storage , err := common .NewOAuthStorage ()
233
254
if err != nil {
234
255
return "" , fmt .Errorf ("failed to create OAuth storage: %w" , err )
235
256
}
236
257
237
258
// Load existing token
238
- token , err := storage .LoadToken ("claude" )
259
+ token , err := storage .LoadToken (tokenIdentifier )
239
260
if err != nil {
240
261
return "" , fmt .Errorf ("failed to load stored token: %w" , err )
241
262
}
@@ -292,7 +313,7 @@ func RefreshToken() (string, error) {
292
313
newToken .RefreshToken = token .RefreshToken
293
314
}
294
315
295
- if err = storage .SaveToken ("claude" , newToken ); err != nil {
316
+ if err = storage .SaveToken (tokenIdentifier , newToken ); err != nil {
296
317
return "" , fmt .Errorf ("failed to save refreshed token: %w" , err )
297
318
}
298
319
0 commit comments