@@ -11,6 +11,7 @@ import (
11
11
12
12
"github.com/danielmiessler/fabric/common"
13
13
"github.com/danielmiessler/fabric/plugins"
14
+ "github.com/danielmiessler/fabric/plugins/ai"
14
15
15
16
"github.com/aws/aws-sdk-go-v2/aws"
16
17
"github.com/aws/aws-sdk-go-v2/aws/middleware"
@@ -22,48 +23,114 @@ import (
22
23
goopenai "github.com/sashabaranov/go-openai"
23
24
)
24
25
25
- // BedrockClient is a plugin to add support for Amazon Bedrock
26
+ const (
27
+ userAgentKey = "aiosc"
28
+ userAgentValue = "fabric"
29
+ )
30
+
31
+ // Ensure BedrockClient implements the ai.Vendor interface
32
+ var _ ai.Vendor = (* BedrockClient )(nil )
33
+
34
+ // BedrockClient is a plugin to add support for Amazon Bedrock.
35
+ // It implements the plugins.Plugin interface and provides methods
36
+ // for interacting with AWS Bedrock's Converse and ConverseStream APIs.
26
37
type BedrockClient struct {
27
38
* plugins.PluginBase
28
39
runtimeClient * bedrockruntime.Client
29
40
controlPlaneClient * bedrock.Client
41
+
42
+ bedrockRegion * plugins.SetupQuestion
30
43
}
31
44
32
45
// NewClient returns a new Bedrock plugin client
33
46
func NewClient () (ret * BedrockClient ) {
34
47
vendorName := "Bedrock"
48
+ ret = & BedrockClient {}
35
49
36
- ctx := context .TODO ()
50
+ ctx := context .Background ()
37
51
cfg , err := config .LoadDefaultConfig (ctx )
38
- cfg .APIOptions = append (cfg .APIOptions , middleware .AddUserAgentKeyValue ("aiosc" , "fabric" ))
39
-
40
52
if err != nil {
41
- fmt .Printf ("Unable to load AWS Config: %s\n " , err )
53
+ // Create a minimal client that will fail gracefully during configuration
54
+ ret .PluginBase = & plugins.PluginBase {
55
+ Name : vendorName ,
56
+ EnvNamePrefix : plugins .BuildEnvVariablePrefix (vendorName ),
57
+ ConfigureCustom : func () error {
58
+ return fmt .Errorf ("unable to load AWS Config: %w" , err )
59
+ },
60
+ }
61
+ ret .bedrockRegion = ret .PluginBase .AddSetupQuestion ("AWS Region" , true )
62
+ return
42
63
}
43
64
65
+ cfg .APIOptions = append (cfg .APIOptions , middleware .AddUserAgentKeyValue (userAgentKey , userAgentValue ))
66
+
44
67
runtimeClient := bedrockruntime .NewFromConfig (cfg )
45
68
controlPlaneClient := bedrock .NewFromConfig (cfg )
46
69
47
- ret = & BedrockClient {
48
- PluginBase : & plugins.PluginBase {
49
- Name : vendorName ,
50
- EnvNamePrefix : plugins .BuildEnvVariablePrefix (vendorName ),
51
- },
52
- runtimeClient : runtimeClient ,
53
- controlPlaneClient : controlPlaneClient ,
70
+ ret .PluginBase = & plugins.PluginBase {
71
+ Name : vendorName ,
72
+ EnvNamePrefix : plugins .BuildEnvVariablePrefix (vendorName ),
73
+ ConfigureCustom : ret .configure ,
74
+ }
75
+
76
+ ret .runtimeClient = runtimeClient
77
+ ret .controlPlaneClient = controlPlaneClient
78
+
79
+ ret .bedrockRegion = ret .PluginBase .AddSetupQuestion ("AWS Region" , true )
80
+
81
+ if cfg .Region != "" {
82
+ ret .bedrockRegion .Value = cfg .Region
54
83
}
55
84
56
85
return
57
86
}
58
87
59
- // ListModels lists the models available for use with the Bedrock plugin
88
+ // isValidAWSRegion validates AWS region format
89
+ func isValidAWSRegion (region string ) bool {
90
+ // Simple validation - AWS regions are typically 2-3 parts separated by hyphens
91
+ // Examples: us-east-1, eu-west-1, ap-southeast-2
92
+ if len (region ) < 5 || len (region ) > 30 {
93
+ return false
94
+ }
95
+ // Basic pattern check for AWS region format
96
+ return region != ""
97
+ }
98
+
99
+ // configure initializes the Bedrock clients with the specified AWS region.
100
+ // If no region is specified, the default region from AWS config is used.
101
+ func (c * BedrockClient ) configure () error {
102
+ if c .bedrockRegion .Value == "" {
103
+ return nil // Use default region from AWS config
104
+ }
105
+
106
+ // Validate region format
107
+ if ! isValidAWSRegion (c .bedrockRegion .Value ) {
108
+ return fmt .Errorf ("invalid AWS region: %s" , c .bedrockRegion .Value )
109
+ }
110
+
111
+ ctx := context .Background ()
112
+ cfg , err := config .LoadDefaultConfig (ctx , config .WithRegion (c .bedrockRegion .Value ))
113
+ if err != nil {
114
+ return fmt .Errorf ("unable to load AWS Config with region %s: %w" , c .bedrockRegion .Value , err )
115
+ }
116
+
117
+ cfg .APIOptions = append (cfg .APIOptions , middleware .AddUserAgentKeyValue (userAgentKey , userAgentValue ))
118
+
119
+ c .runtimeClient = bedrockruntime .NewFromConfig (cfg )
120
+ c .controlPlaneClient = bedrock .NewFromConfig (cfg )
121
+
122
+ return nil
123
+ }
124
+
125
+ // ListModels retrieves all available foundation models and inference profiles
126
+ // from AWS Bedrock that can be used with this plugin.
60
127
func (c * BedrockClient ) ListModels () ([]string , error ) {
61
128
models := []string {}
62
- ctx := context .TODO ()
129
+ ctx := context .Background ()
63
130
64
131
foundationModels , err := c .controlPlaneClient .ListFoundationModels (ctx , & bedrock.ListFoundationModelsInput {})
65
132
if err != nil {
66
- return nil , err
133
+ return nil , fmt . Errorf ( "failed to list foundation models: %w" , err )
67
134
}
68
135
69
136
for _ , model := range foundationModels .ModelSummaries {
@@ -73,9 +140,9 @@ func (c *BedrockClient) ListModels() ([]string, error) {
73
140
inferenceProfilesPaginator := bedrock .NewListInferenceProfilesPaginator (c .controlPlaneClient , & bedrock.ListInferenceProfilesInput {})
74
141
75
142
for inferenceProfilesPaginator .HasMorePages () {
76
- inferenceProfiles , err := inferenceProfilesPaginator .NextPage (context . TODO () )
143
+ inferenceProfiles , err := inferenceProfilesPaginator .NextPage (ctx )
77
144
if err != nil {
78
- return nil , err
145
+ return nil , fmt . Errorf ( "failed to list inference profiles: %w" , err )
79
146
}
80
147
81
148
for _ , profile := range inferenceProfiles .InferenceProfileSummaries {
@@ -88,6 +155,13 @@ func (c *BedrockClient) ListModels() ([]string, error) {
88
155
89
156
// SendStream sends the messages to the the Bedrock ConverseStream API
90
157
func (c * BedrockClient ) SendStream (msgs []* goopenai.ChatCompletionMessage , opts * common.ChatOptions , channel chan string ) (err error ) {
158
+ // Ensure channel is closed on all exit paths to prevent goroutine leaks
159
+ defer func () {
160
+ if r := recover (); r != nil {
161
+ err = fmt .Errorf ("panic in SendStream: %v" , r )
162
+ }
163
+ close (channel )
164
+ }()
91
165
92
166
messages := c .toMessages (msgs )
93
167
@@ -99,10 +173,9 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
99
173
TopP : aws .Float32 (float32 (opts .TopP ))},
100
174
}
101
175
102
- response , err := c .runtimeClient .ConverseStream (context .TODO (), & converseInput )
176
+ response , err := c .runtimeClient .ConverseStream (context .Background (), & converseInput )
103
177
if err != nil {
104
- fmt .Printf ("Error conversing with Bedrock: %s\n " , err )
105
- return
178
+ return fmt .Errorf ("bedrock conversestream failed for model %s: %w" , opts .Model , err )
106
179
}
107
180
108
181
for event := range response .GetStream ().Events () {
@@ -118,7 +191,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
118
191
119
192
case * types.ConverseStreamOutputMemberMessageStop :
120
193
channel <- "\n "
121
- close ( channel )
194
+ return nil // Let defer handle the close
122
195
123
196
// Unused Events
124
197
case * types.ConverseStreamOutputMemberMessageStart ,
@@ -127,7 +200,7 @@ func (c *BedrockClient) SendStream(msgs []*goopenai.ChatCompletionMessage, opts
127
200
* types.ConverseStreamOutputMemberMetadata :
128
201
129
202
default :
130
- fmt .Printf ( "Error: Unknown stream event type: %T\n " , v )
203
+ return fmt .Errorf ( "unknown stream event type: %T" , v )
131
204
}
132
205
}
133
206
@@ -145,22 +218,35 @@ func (c *BedrockClient) Send(ctx context.Context, msgs []*goopenai.ChatCompletio
145
218
}
146
219
response , err := c .runtimeClient .Converse (ctx , & converseInput )
147
220
if err != nil {
148
- fmt .Printf ("Error conversing with Bedrock: %s\n " , err )
149
- return "" , err
221
+ return "" , fmt .Errorf ("bedrock converse failed for model %s: %w" , opts .Model , err )
222
+ }
223
+
224
+ responseText , ok := response .Output .(* types.ConverseOutputMemberMessage )
225
+ if ! ok {
226
+ return "" , fmt .Errorf ("unexpected response type: %T" , response .Output )
227
+ }
228
+
229
+ if len (responseText .Value .Content ) == 0 {
230
+ return "" , fmt .Errorf ("empty response content" )
150
231
}
151
232
152
- responseText , _ := response .Output .(* types.ConverseOutputMemberMessage )
153
233
responseContentBlock := responseText .Value .Content [0 ]
154
- text , _ := responseContentBlock .(* types.ContentBlockMemberText )
234
+ text , ok := responseContentBlock .(* types.ContentBlockMemberText )
235
+ if ! ok {
236
+ return "" , fmt .Errorf ("unexpected content block type: %T" , responseContentBlock )
237
+ }
238
+
155
239
return text .Value , nil
156
240
}
157
241
242
+ // NeedsRawMode indicates whether the model requires raw mode processing.
243
+ // Bedrock models do not require raw mode.
158
244
func (c * BedrockClient ) NeedsRawMode (modelName string ) bool {
159
245
return false
160
246
}
161
247
162
248
// toMessages converts the array of input messages from the ChatCompletionMessageType to the
163
- // Bedrock Converse Message type
249
+ // Bedrock Converse Message type.
164
250
// The system role messages are mapped to the user role as they contain a mix of system messages,
165
251
// pattern content and user input.
166
252
func (c * BedrockClient ) toMessages (inputMessages []* goopenai.ChatCompletionMessage ) (messages []types.Message ) {
0 commit comments