Skip to content

Commit 369a0a8

Browse files
authored
Merge pull request #1565 from ksylvan/0701-claude-oauth-support
OAuth Authentication Support for Anthropic
2 parents acf1be7 + 8dc5343 commit 369a0a8

File tree

6 files changed

+753
-47
lines changed

6 files changed

+753
-47
lines changed

common/oauth_storage.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package common
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"time"
9+
)
10+
11+
// OAuthToken represents stored OAuth token information
12+
type OAuthToken struct {
13+
AccessToken string `json:"access_token"`
14+
RefreshToken string `json:"refresh_token"`
15+
ExpiresAt int64 `json:"expires_at"`
16+
TokenType string `json:"token_type"`
17+
Scope string `json:"scope"`
18+
}
19+
20+
// IsExpired checks if the token is expired or will expire within the buffer time
21+
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
22+
if t.ExpiresAt == 0 {
23+
return true
24+
}
25+
bufferTime := time.Duration(bufferMinutes) * time.Minute
26+
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
27+
}
28+
29+
// OAuthStorage handles persistent storage of OAuth tokens
30+
type OAuthStorage struct {
31+
configDir string
32+
}
33+
34+
// NewOAuthStorage creates a new OAuth storage instance
35+
func NewOAuthStorage() (*OAuthStorage, error) {
36+
homeDir, err := os.UserHomeDir()
37+
if err != nil {
38+
return nil, fmt.Errorf("failed to get user home directory: %w", err)
39+
}
40+
41+
configDir := filepath.Join(homeDir, ".config", "fabric")
42+
43+
// Ensure config directory exists
44+
if err := os.MkdirAll(configDir, 0755); err != nil {
45+
return nil, fmt.Errorf("failed to create config directory: %w", err)
46+
}
47+
48+
return &OAuthStorage{configDir: configDir}, nil
49+
}
50+
51+
// GetTokenPath returns the file path for a provider's OAuth token
52+
func (s *OAuthStorage) GetTokenPath(provider string) string {
53+
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
54+
}
55+
56+
// SaveToken saves an OAuth token to disk with proper permissions
57+
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
58+
tokenPath := s.GetTokenPath(provider)
59+
60+
// Marshal token to JSON
61+
data, err := json.MarshalIndent(token, "", " ")
62+
if err != nil {
63+
return fmt.Errorf("failed to marshal token: %w", err)
64+
}
65+
66+
// Write to temporary file first for atomic operation
67+
tempPath := tokenPath + ".tmp"
68+
if err := os.WriteFile(tempPath, data, 0600); err != nil {
69+
return fmt.Errorf("failed to write token file: %w", err)
70+
}
71+
72+
// Atomic rename
73+
if err := os.Rename(tempPath, tokenPath); err != nil {
74+
os.Remove(tempPath) // Clean up temp file
75+
return fmt.Errorf("failed to save token file: %w", err)
76+
}
77+
78+
return nil
79+
}
80+
81+
// LoadToken loads an OAuth token from disk
82+
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
83+
tokenPath := s.GetTokenPath(provider)
84+
85+
// Check if file exists
86+
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
87+
return nil, nil // No token stored
88+
}
89+
90+
// Read token file
91+
data, err := os.ReadFile(tokenPath)
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to read token file: %w", err)
94+
}
95+
96+
// Unmarshal token
97+
var token OAuthToken
98+
if err := json.Unmarshal(data, &token); err != nil {
99+
return nil, fmt.Errorf("failed to parse token file: %w", err)
100+
}
101+
102+
return &token, nil
103+
}
104+
105+
// DeleteToken removes a stored OAuth token
106+
func (s *OAuthStorage) DeleteToken(provider string) error {
107+
tokenPath := s.GetTokenPath(provider)
108+
109+
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
110+
return fmt.Errorf("failed to delete token file: %w", err)
111+
}
112+
113+
return nil
114+
}
115+
116+
// HasValidToken checks if a valid (non-expired) token exists for a provider
117+
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
118+
token, err := s.LoadToken(provider)
119+
if err != nil || token == nil {
120+
return false
121+
}
122+
123+
return !token.IsExpired(bufferMinutes)
124+
}

common/oauth_storage_test.go

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package common
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"testing"
7+
"time"
8+
)
9+
10+
func TestOAuthToken_IsExpired(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
expiresAt int64
14+
bufferMinutes int
15+
expected bool
16+
}{
17+
{
18+
name: "token not expired",
19+
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
20+
bufferMinutes: 5,
21+
expected: false,
22+
},
23+
{
24+
name: "token expired",
25+
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
26+
bufferMinutes: 5,
27+
expected: true,
28+
},
29+
{
30+
name: "token expires within buffer",
31+
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
32+
bufferMinutes: 5,
33+
expected: true,
34+
},
35+
{
36+
name: "zero expiry time",
37+
expiresAt: 0,
38+
bufferMinutes: 5,
39+
expected: true,
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(tt.name, func(t *testing.T) {
45+
token := &OAuthToken{ExpiresAt: tt.expiresAt}
46+
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
47+
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
48+
}
49+
})
50+
}
51+
}
52+
53+
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
54+
// Create temporary directory for testing
55+
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
56+
if err != nil {
57+
t.Fatalf("Failed to create temp dir: %v", err)
58+
}
59+
defer os.RemoveAll(tempDir)
60+
61+
// Create storage with custom config dir
62+
storage := &OAuthStorage{configDir: tempDir}
63+
64+
// Test token
65+
token := &OAuthToken{
66+
AccessToken: "test_access_token",
67+
RefreshToken: "test_refresh_token",
68+
ExpiresAt: time.Now().Unix() + 3600,
69+
TokenType: "Bearer",
70+
Scope: "test_scope",
71+
}
72+
73+
// Test saving token
74+
err = storage.SaveToken("test_provider", token)
75+
if err != nil {
76+
t.Fatalf("Failed to save token: %v", err)
77+
}
78+
79+
// Verify file exists and has correct permissions
80+
tokenPath := storage.GetTokenPath("test_provider")
81+
info, err := os.Stat(tokenPath)
82+
if err != nil {
83+
t.Fatalf("Token file not created: %v", err)
84+
}
85+
if info.Mode().Perm() != 0600 {
86+
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
87+
}
88+
89+
// Test loading token
90+
loadedToken, err := storage.LoadToken("test_provider")
91+
if err != nil {
92+
t.Fatalf("Failed to load token: %v", err)
93+
}
94+
if loadedToken == nil {
95+
t.Fatal("Loaded token is nil")
96+
}
97+
98+
// Verify token data
99+
if loadedToken.AccessToken != token.AccessToken {
100+
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
101+
}
102+
if loadedToken.RefreshToken != token.RefreshToken {
103+
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
104+
}
105+
if loadedToken.ExpiresAt != token.ExpiresAt {
106+
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
107+
}
108+
}
109+
110+
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
111+
// Create temporary directory for testing
112+
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
113+
if err != nil {
114+
t.Fatalf("Failed to create temp dir: %v", err)
115+
}
116+
defer os.RemoveAll(tempDir)
117+
118+
storage := &OAuthStorage{configDir: tempDir}
119+
120+
// Try to load non-existent token
121+
token, err := storage.LoadToken("nonexistent")
122+
if err != nil {
123+
t.Fatalf("Unexpected error loading non-existent token: %v", err)
124+
}
125+
if token != nil {
126+
t.Error("Expected nil token for non-existent provider")
127+
}
128+
}
129+
130+
func TestOAuthStorage_DeleteToken(t *testing.T) {
131+
// Create temporary directory for testing
132+
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
133+
if err != nil {
134+
t.Fatalf("Failed to create temp dir: %v", err)
135+
}
136+
defer os.RemoveAll(tempDir)
137+
138+
storage := &OAuthStorage{configDir: tempDir}
139+
140+
// Create and save a token
141+
token := &OAuthToken{
142+
AccessToken: "test_token",
143+
RefreshToken: "test_refresh",
144+
ExpiresAt: time.Now().Unix() + 3600,
145+
}
146+
err = storage.SaveToken("test_provider", token)
147+
if err != nil {
148+
t.Fatalf("Failed to save token: %v", err)
149+
}
150+
151+
// Verify token exists
152+
tokenPath := storage.GetTokenPath("test_provider")
153+
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
154+
t.Fatal("Token file should exist before deletion")
155+
}
156+
157+
// Delete token
158+
err = storage.DeleteToken("test_provider")
159+
if err != nil {
160+
t.Fatalf("Failed to delete token: %v", err)
161+
}
162+
163+
// Verify token is deleted
164+
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
165+
t.Error("Token file should not exist after deletion")
166+
}
167+
168+
// Test deleting non-existent token (should not error)
169+
err = storage.DeleteToken("nonexistent")
170+
if err != nil {
171+
t.Errorf("Deleting non-existent token should not error: %v", err)
172+
}
173+
}
174+
175+
func TestOAuthStorage_HasValidToken(t *testing.T) {
176+
// Create temporary directory for testing
177+
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
178+
if err != nil {
179+
t.Fatalf("Failed to create temp dir: %v", err)
180+
}
181+
defer os.RemoveAll(tempDir)
182+
183+
storage := &OAuthStorage{configDir: tempDir}
184+
185+
// Test with no token
186+
if storage.HasValidToken("test_provider", 5) {
187+
t.Error("Should return false when no token exists")
188+
}
189+
190+
// Save valid token
191+
validToken := &OAuthToken{
192+
AccessToken: "valid_token",
193+
RefreshToken: "refresh_token",
194+
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
195+
}
196+
err = storage.SaveToken("test_provider", validToken)
197+
if err != nil {
198+
t.Fatalf("Failed to save valid token: %v", err)
199+
}
200+
201+
// Test with valid token
202+
if !storage.HasValidToken("test_provider", 5) {
203+
t.Error("Should return true for valid token")
204+
}
205+
206+
// Save expired token
207+
expiredToken := &OAuthToken{
208+
AccessToken: "expired_token",
209+
RefreshToken: "refresh_token",
210+
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
211+
}
212+
err = storage.SaveToken("expired_provider", expiredToken)
213+
if err != nil {
214+
t.Fatalf("Failed to save expired token: %v", err)
215+
}
216+
217+
// Test with expired token
218+
if storage.HasValidToken("expired_provider", 5) {
219+
t.Error("Should return false for expired token")
220+
}
221+
}
222+
223+
func TestOAuthStorage_GetTokenPath(t *testing.T) {
224+
storage := &OAuthStorage{configDir: "/test/config"}
225+
226+
expected := filepath.Join("/test/config", ".test_provider_oauth")
227+
actual := storage.GetTokenPath("test_provider")
228+
229+
if actual != expected {
230+
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
231+
}
232+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ require (
2525
github.com/samber/lo v1.50.0
2626
github.com/sgaunet/perplexity-go/v2 v2.8.0
2727
github.com/stretchr/testify v1.10.0
28+
golang.org/x/oauth2 v0.30.0
2829
golang.org/x/text v0.26.0
2930
google.golang.org/api v0.236.0
3031
gopkg.in/yaml.v3 v3.0.1
@@ -108,7 +109,6 @@ require (
108109
golang.org/x/arch v0.18.0 // indirect
109110
golang.org/x/crypto v0.39.0 // indirect
110111
golang.org/x/net v0.41.0 // indirect
111-
golang.org/x/oauth2 v0.30.0 // indirect
112112
golang.org/x/sync v0.15.0 // indirect
113113
golang.org/x/sys v0.33.0 // indirect
114114
golang.org/x/time v0.12.0 // indirect

0 commit comments

Comments
 (0)