Skip to content

Commit c0ea25f

Browse files
authored
Merge pull request #1603 from ksylvan/0711-together-ai-implementation
Together AI Support with OpenAI Fallback Mechanism Added
2 parents 6b07b33 + 87796d4 commit c0ea25f

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package openai_compatible
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"time"
11+
)
12+
13+
// Model represents a single model entry returned by a provider's /models
// endpoint. Only the identifier is decoded; any other fields in the
// provider's response are ignored.
type Model struct {
	ID string `json:"id"`
}
17+
18+
// errorResponseLimit defines the maximum length (in bytes) of error response
// bodies included in error messages; longer bodies are truncated.
const errorResponseLimit = 1024
20+
21+
// DirectlyGetModels is used to fetch models directly from the API
22+
// when the standard OpenAI SDK method fails due to a nonstandard format.
23+
// This is useful for providers like Together that return a direct array of models.
24+
func (c *Client) DirectlyGetModels(ctx context.Context) ([]string, error) {
25+
if ctx == nil {
26+
ctx = context.Background()
27+
}
28+
baseURL := c.ApiBaseURL.Value
29+
if baseURL == "" {
30+
return nil, fmt.Errorf("API base URL not configured for provider %s", c.GetName())
31+
}
32+
33+
// Build the /models endpoint URL
34+
fullURL, err := url.JoinPath(baseURL, "models")
35+
if err != nil {
36+
return nil, fmt.Errorf("failed to create models URL: %w", err)
37+
}
38+
39+
req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.ApiKey.Value))
45+
req.Header.Set("Accept", "application/json")
46+
47+
// TODO: Consider reusing a single http.Client instance (e.g., as a field on Client) instead of allocating a new one for each request.
48+
49+
client := &http.Client{
50+
Timeout: 10 * time.Second,
51+
}
52+
resp, err := client.Do(req)
53+
if err != nil {
54+
return nil, err
55+
}
56+
defer resp.Body.Close()
57+
58+
if resp.StatusCode != http.StatusOK {
59+
// Read the response body for debugging
60+
bodyBytes, _ := io.ReadAll(resp.Body)
61+
bodyString := string(bodyBytes)
62+
if len(bodyString) > errorResponseLimit { // Truncate if too large
63+
bodyString = bodyString[:errorResponseLimit] + "..."
64+
}
65+
return nil, fmt.Errorf("unexpected status code: %d from provider %s, response body: %s",
66+
resp.StatusCode, c.GetName(), bodyString)
67+
}
68+
69+
// Read the response body once
70+
bodyBytes, err := io.ReadAll(resp.Body)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
// Try to parse as an object with data field (OpenAI format)
76+
var openAIFormat struct {
77+
Data []Model `json:"data"`
78+
}
79+
// Try to parse as a direct array (Together format)
80+
var directArray []Model
81+
82+
if err := json.Unmarshal(bodyBytes, &openAIFormat); err == nil && len(openAIFormat.Data) > 0 {
83+
return extractModelIDs(openAIFormat.Data), nil
84+
}
85+
86+
if err := json.Unmarshal(bodyBytes, &directArray); err == nil && len(directArray) > 0 {
87+
return extractModelIDs(directArray), nil
88+
}
89+
90+
var truncatedBody string
91+
if len(bodyBytes) > errorResponseLimit {
92+
truncatedBody = string(bodyBytes[:errorResponseLimit]) + "..."
93+
} else {
94+
truncatedBody = string(bodyBytes)
95+
}
96+
return nil, fmt.Errorf("unable to parse models response; raw response: %s", truncatedBody)
97+
}
98+
99+
func extractModelIDs(models []Model) []string {
100+
modelIDs := make([]string, 0, len(models))
101+
for _, model := range models {
102+
modelIDs = append(modelIDs, model.ID)
103+
}
104+
return modelIDs
105+
}

internal/plugins/ai/openai_compatible/providers_config.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai_compatible
22

33
import (
4+
"context"
45
"os"
56
"strings"
67

@@ -31,6 +32,19 @@ func NewClient(providerConfig ProviderConfig) *Client {
3132
return client
3233
}
3334

35+
// ListModels overrides the default ListModels to handle different response formats
36+
func (c *Client) ListModels() ([]string, error) {
37+
// First try the standard OpenAI SDK approach
38+
models, err := c.Client.ListModels()
39+
if err == nil && len(models) > 0 { // only return if OpenAI SDK returns models
40+
return models, nil
41+
}
42+
43+
// TODO: Handle context properly in Fabric by accepting and propagating a context.Context
44+
// instead of creating a new one here.
45+
return c.DirectlyGetModels(context.Background())
46+
}
47+
3448
// ProviderMap is a map of provider name to ProviderConfig for O(1) lookup
3549
var ProviderMap = map[string]ProviderConfig{
3650
"AIML": {
@@ -83,6 +97,11 @@ var ProviderMap = map[string]ProviderConfig{
8397
BaseURL: "https://api.siliconflow.cn/v1",
8498
ImplementsResponses: false,
8599
},
100+
"Together": {
101+
Name: "Together",
102+
BaseURL: "https://api.together.xyz/v1",
103+
ImplementsResponses: false,
104+
},
86105
}
87106

88107
// GetProviderByName returns the provider configuration for a given name with O(1) lookup

0 commit comments

Comments
 (0)