Skip to content

Commit 9f8a2d3

Browse files
authored
Merge pull request #1538 from andrewsjg/bug/bedrock-region-handling
Bug/bedrock region handling
2 parents 4353bc9 + 3fd923f commit 9f8a2d3

File tree

2 files changed

+114
-27
lines changed

2 files changed

+114
-27
lines changed

core/plugin_registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ func hasAWSCredentials() bool {
4545
if os.Getenv("AWS_PROFILE") != "" ||
4646
os.Getenv("AWS_ROLE_SESSION_NAME") != "" ||
4747
(os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "") {
48+
4849
return true
4950
}
5051

plugins/ai/bedrock/bedrock.go

Lines changed: 113 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/danielmiessler/fabric/common"
1313
"github.com/danielmiessler/fabric/plugins"
14+
"github.com/danielmiessler/fabric/plugins/ai"
1415

1516
"github.com/aws/aws-sdk-go-v2/aws"
1617
"github.com/aws/aws-sdk-go-v2/aws/middleware"
@@ -22,48 +23,114 @@ import (
2223
goopenai "github.com/sashabaranov/go-openai"
2324
)
2425

25-
// BedrockClient is a plugin to add support for Amazon Bedrock
26+
const (
27+
userAgentKey = "aiosc"
28+
userAgentValue = "fabric"
29+
)
30+
31+
// Ensure BedrockClient implements the ai.Vendor interface
32+
var _ ai.Vendor = (*BedrockClient)(nil)
33+
34+
// BedrockClient is a plugin to add support for Amazon Bedrock.
35+
// It implements the plugins.Plugin interface and provides methods
36+
// for interacting with AWS Bedrock's Converse and ConverseStream APIs.
2637
type BedrockClient struct {
2738
*plugins.PluginBase
2839
runtimeClient *bedrockruntime.Client
2940
controlPlaneClient *bedrock.Client
41+
42+
bedrockRegion *plugins.SetupQuestion
3043
}
3144

3245
// NewClient returns a new Bedrock plugin client
3346
func NewClient() (ret *BedrockClient) {
3447
vendorName := "Bedrock"
48+
ret = &BedrockClient{}
3549

36-
ctx := context.TODO()
50+
ctx := context.Background()
3751
cfg, err := config.LoadDefaultConfig(ctx)
38-
cfg.APIOptions = append(cfg.APIOptions, middleware.AddUserAgentKeyValue("aiosc", "fabric"))
39-
4052
if err != nil {
41-
fmt.Printf("Unable to load AWS Config: %s\n", err)
53+
// Create a minimal client that will fail gracefully during configuration
54+
ret.PluginBase = &plugins.PluginBase{
55+
Name: vendorName,
56+
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
57+
ConfigureCustom: func() error {
58+
return fmt.Errorf("unable to load AWS Config: %w", err)
59+
},
60+
}
61+
ret.bedrockRegion = ret.PluginBase.AddSetupQuestion("AWS Region", true)
62+
return
4263
}
4364

65+
cfg.APIOptions = append(cfg.APIOptions, middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))
66+
4467
runtimeClient := bedrockruntime.NewFromConfig(cfg)
4568
controlPlaneClient := bedrock.NewFromConfig(cfg)
4669

47-
ret = &BedrockClient{
48-
PluginBase: &plugins.PluginBase{
49-
Name: vendorName,
50-
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
51-
},
52-
runtimeClient: runtimeClient,
53-
controlPlaneClient: controlPlaneClient,
70+
ret.PluginBase = &plugins.PluginBase{
71+
Name: vendorName,
72+
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
73+
ConfigureCustom: ret.configure,
74+
}
75+
76+
ret.runtimeClient = runtimeClient
77+
ret.controlPlaneClient = controlPlaneClient
78+
79+
ret.bedrockRegion = ret.PluginBase.AddSetupQuestion("AWS Region", true)
80+
81+
if cfg.Region != "" {
82+
ret.bedrockRegion.Value = cfg.Region
5483
}
5584

5685
return
5786
}
5887

59-
// ListModels lists the models available for use with the Bedrock plugin
88+
// isValidAWSRegion validates AWS region format
89+
func isValidAWSRegion(region string) bool {
90+
// Simple validation - AWS regions are typically 2-3 parts separated by hyphens
91+
// Examples: us-east-1, eu-west-1, ap-southeast-2
92+
if len(region) < 5 || len(region) > 30 {
93+
return false
94+
}
95+
// Basic pattern check for AWS region format
96+
return region != ""
97+
}
98+
99+
// configure initializes the Bedrock clients with the specified AWS region.
100+
// If no region is specified, the default region from AWS config is used.
101+
func (c *BedrockClient) configure() error {
102+
if c.bedrockRegion.Value == "" {
103+
return nil // Use default region from AWS config
104+
}
105+
106+
// Validate region format
107+
if !isValidAWSRegion(c.bedrockRegion.Value) {
108+
return fmt.Errorf("invalid AWS region: %s", c.bedrockRegion.Value)
109+
}
110+
111+
ctx := context.Background()
112+
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(c.bedrockRegion.Value))
113+
if err != nil {
114+
return fmt.Errorf("unable to load AWS Config with region %s: %w", c.bedrockRegion.Value, err)
115+
}
116+
117+
cfg.APIOptions = append(cfg.APIOptions, middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))
118+
119+
c.runtimeClient = bedrockruntime.NewFromConfig(cfg)
120+
c.controlPlaneClient = bedrock.NewFromConfig(cfg)
121+
122+
return nil
123+
}
124+
125+
// ListModels retrieves all available foundation models and inference profiles
126+
// from AWS Bedrock that can be used with this plugin.
60127
func (c *BedrockClient) ListModels() ([]string, error) {
61128
models := []string{}
62-
ctx := context.TODO()
129+
ctx := context.Background()
63130

64131
foundationModels, err := c.controlPlaneClient.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{})
65132
if err != nil {
66-
return nil, err
133+
return nil, fmt.Errorf("failed to list foundation models: %w", err)
67134
}
68135

69136
for _, model := range foundationModels.ModelSummaries {
@@ -73,9 +140,9 @@ func (c *BedrockClient) ListModels() ([]string, error) {
73140
inferenceProfilesPaginator := bedrock.NewListInferenceProfilesPaginator(c.controlPlaneClient, &bedrock.ListInferenceProfilesInput{})
74141

75142
for inferenceProfilesPaginator.HasMorePages() {
76-
inferenceProfiles, err := inferenceProfilesPaginator.NextPage(context.TODO())
143+
inferenceProfiles, err := inferenceProfilesPaginator.NextPage(ctx)
77144
if err != nil {
78-
return nil, err
145+
return nil, fmt.Errorf("failed to list inference profiles: %w", err)
79146
}
80147

81148
for _, profile := range inferenceProfiles.InferenceProfileSummaries {
@@ -88,6 +155,13 @@ func (c *BedrockClient) ListModels() ([]string, error) {
88155

89156
// SendStream sends the messages to the the Bedrock ConverseStream API
90157
func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) (err error) {
158+
// Ensure channel is closed on all exit paths to prevent goroutine leaks
159+
defer func() {
160+
if r := recover(); r != nil {
161+
err = fmt.Errorf("panic in SendStream: %v", r)
162+
}
163+
close(channel)
164+
}()
91165

92166
messages := c.toMessages(msgs)
93167

@@ -99,10 +173,9 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
99173
TopP: aws.Float32(float32(opts.TopP))},
100174
}
101175

102-
response, err := c.runtimeClient.ConverseStream(context.TODO(), &converseInput)
176+
response, err := c.runtimeClient.ConverseStream(context.Background(), &converseInput)
103177
if err != nil {
104-
fmt.Printf("Error conversing with Bedrock: %s\n", err)
105-
return
178+
return fmt.Errorf("bedrock conversestream failed for model %s: %w", opts.Model, err)
106179
}
107180

108181
for event := range response.GetStream().Events() {
@@ -118,7 +191,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
118191

119192
case *types.ConverseStreamOutputMemberMessageStop:
120193
channel <- "\n"
121-
close(channel)
194+
return nil // Let defer handle the close
122195

123196
// Unused Events
124197
case *types.ConverseStreamOutputMemberMessageStart,
@@ -127,7 +200,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
127200
*types.ConverseStreamOutputMemberMetadata:
128201

129202
default:
130-
fmt.Printf("Error: Unknown stream event type: %T\n", v)
203+
return fmt.Errorf("unknown stream event type: %T", v)
131204
}
132205
}
133206

@@ -145,22 +218,35 @@ func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletio
145218
}
146219
response, err := c.runtimeClient.Converse(ctx, &converseInput)
147220
if err != nil {
148-
fmt.Printf("Error conversing with Bedrock: %s\n", err)
149-
return "", err
221+
return "", fmt.Errorf("bedrock converse failed for model %s: %w", opts.Model, err)
222+
}
223+
224+
responseText, ok := response.Output.(*types.ConverseOutputMemberMessage)
225+
if !ok {
226+
return "", fmt.Errorf("unexpected response type: %T", response.Output)
227+
}
228+
229+
if len(responseText.Value.Content) == 0 {
230+
return "", fmt.Errorf("empty response content")
150231
}
151232

152-
responseText, _ := response.Output.(*types.ConverseOutputMemberMessage)
153233
responseContentBlock := responseText.Value.Content[0]
154-
text, _ := responseContentBlock.(*types.ContentBlockMemberText)
234+
text, ok := responseContentBlock.(*types.ContentBlockMemberText)
235+
if !ok {
236+
return "", fmt.Errorf("unexpected content block type: %T", responseContentBlock)
237+
}
238+
155239
return text.Value, nil
156240
}
157241

242+
// NeedsRawMode indicates whether the model requires raw mode processing.
243+
// Bedrock models do not require raw mode.
158244
func (c *BedrockClient) NeedsRawMode(modelName string) bool {
159245
return false
160246
}
161247

162248
// toMessages converts the array of input messages from the ChatCompletionMessageType to the
163-
// Bedrock Converse Message type
249+
// Bedrock Converse Message type.
164250
// The system role messages are mapped to the user role as they contain a mix of system messages,
165251
// pattern content and user input.
166252
func (c *BedrockClient) toMessages(inputMessages []*goopenai.ChatCompletionMessage) (messages []types.Message) {

0 commit comments

Comments
 (0)