Skip to content

Commit 4e16bbc

Browse files
authored
Merge pull request #1574 from ksylvan/0704-image-tool-model-validation
Add Model Validation for Image Generation and Fix CLI Flag Mapping
2 parents 5858311 + 60174f4 commit 4e16bbc

File tree

5 files changed

+146
-0
lines changed

5 files changed

+146
-0
lines changed

cli/flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
289289
}
290290

291291
ret = &common.ChatOptions{
292+
Model: o.Model,
292293
Temperature: o.Temperature,
293294
TopP: o.TopP,
294295
PresencePenalty: o.PresencePenalty,

mars-colony.png

-1.79 MB
Binary file not shown.

plugins/ai/openai/openai.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
128128
}
129129

130130
func (o *Client) sendResponses(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *common.ChatOptions) (ret string, err error) {
131+
// Validate model supports image generation if image file is specified
132+
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
133+
return "", fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
134+
}
135+
131136
req := o.buildResponseParams(msgs, opts)
132137

133138
var resp *responses.Response

plugins/ai/openai/openai_image.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ import (
1818
const ImageGenerationResponseType = "image_generation_call"
1919
const ImageGenerationToolType = "image_generation"
2020

21+
// ImageGenerationSupportedModels lists the base names of all models that
// support image generation. Dated snapshots of these models (for example
// "gpt-4o-2024-08-06") are accepted by supportsImageGeneration as well.
var ImageGenerationSupportedModels = []string{
	"gpt-4o",
	"gpt-4o-mini",
	"gpt-4.1",
	"gpt-4.1-mini",
	"gpt-4.1-nano",
	"o3",
}

// supportsImageGeneration checks if the given model supports the
// image_generation tool.
//
// A model is supported when it is listed in ImageGenerationSupportedModels,
// or when it is a dated snapshot of a listed model, i.e. the listed name
// followed by a "-YYYY-MM-DD" suffix. Any other suffix is rejected, so
// "o3-mini" is NOT treated as a variant of "o3". The empty string is never
// supported.
func supportsImageGeneration(model string) bool {
	for _, supportedModel := range ImageGenerationSupportedModels {
		if model == supportedModel || isDatedSnapshot(model, supportedModel) {
			return true
		}
	}
	return false
}

// isDatedSnapshot reports whether model is exactly base followed by a
// "-YYYY-MM-DD" suffix. Only the shape of the suffix is checked (dashes and
// ASCII digits); the date itself is not validated.
func isDatedSnapshot(model, base string) bool {
	const suffixLen = len("-2006-01-02")
	if len(model) != len(base)+suffixLen || model[:len(base)] != base {
		return false
	}
	suffix := model[len(base):]
	for i := 0; i < len(suffix); i++ {
		c := suffix[i]
		switch i {
		case 0, 5, 8: // separator positions in "-YYYY-MM-DD"
			if c != '-' {
				return false
			}
		default:
			if c < '0' || c > '9' {
				return false
			}
		}
	}
	return true
}
40+
2141
// getOutputFormatFromExtension determines the API output format based on file extension
2242
func getOutputFormatFromExtension(imagePath string) string {
2343
if imagePath == "" {

plugins/ai/openai/openai_image_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package openai
22

33
import (
4+
"fmt"
5+
"strings"
46
"testing"
57

68
"github.com/danielmiessler/fabric/chat"
@@ -218,3 +220,121 @@ func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
218220
})
219221
}
220222
}
223+
224+
func TestSupportsImageGeneration(t *testing.T) {
225+
tests := []struct {
226+
name string
227+
model string
228+
expected bool
229+
}{
230+
{
231+
name: "gpt-4o supports image generation",
232+
model: "gpt-4o",
233+
expected: true,
234+
},
235+
{
236+
name: "gpt-4o-mini supports image generation",
237+
model: "gpt-4o-mini",
238+
expected: true,
239+
},
240+
{
241+
name: "gpt-4.1 supports image generation",
242+
model: "gpt-4.1",
243+
expected: true,
244+
},
245+
{
246+
name: "gpt-4.1-mini supports image generation",
247+
model: "gpt-4.1-mini",
248+
expected: true,
249+
},
250+
{
251+
name: "gpt-4.1-nano supports image generation",
252+
model: "gpt-4.1-nano",
253+
expected: true,
254+
},
255+
{
256+
name: "o3 supports image generation",
257+
model: "o3",
258+
expected: true,
259+
},
260+
{
261+
name: "o1 does not support image generation",
262+
model: "o1",
263+
expected: false,
264+
},
265+
{
266+
name: "o1-mini does not support image generation",
267+
model: "o1-mini",
268+
expected: false,
269+
},
270+
{
271+
name: "o3-mini does not support image generation",
272+
model: "o3-mini",
273+
expected: false,
274+
},
275+
{
276+
name: "gpt-4 does not support image generation",
277+
model: "gpt-4",
278+
expected: false,
279+
},
280+
{
281+
name: "gpt-3.5-turbo does not support image generation",
282+
model: "gpt-3.5-turbo",
283+
expected: false,
284+
},
285+
{
286+
name: "empty model does not support image generation",
287+
model: "",
288+
expected: false,
289+
},
290+
}
291+
292+
for _, tt := range tests {
293+
t.Run(tt.name, func(t *testing.T) {
294+
result := supportsImageGeneration(tt.model)
295+
assert.Equal(t, tt.expected, result)
296+
})
297+
}
298+
}
299+
300+
func TestModelValidationLogic(t *testing.T) {
301+
t.Run("Unsupported model with image file should return validation error", func(t *testing.T) {
302+
opts := &common.ChatOptions{
303+
Model: "o1-mini",
304+
ImageFile: "/tmp/output.png",
305+
}
306+
307+
// Test the validation logic directly
308+
if opts.ImageFile != "" && !supportsImageGeneration(opts.Model) {
309+
err := fmt.Errorf("model '%s' does not support image generation. Supported models: %s", opts.Model, strings.Join(ImageGenerationSupportedModels, ", "))
310+
311+
assert.Contains(t, err.Error(), "does not support image generation")
312+
assert.Contains(t, err.Error(), "o1-mini")
313+
assert.Contains(t, err.Error(), "Supported models:")
314+
} else {
315+
t.Error("Expected validation to trigger")
316+
}
317+
})
318+
319+
t.Run("Supported model with image file should not trigger validation", func(t *testing.T) {
320+
opts := &common.ChatOptions{
321+
Model: "gpt-4o",
322+
ImageFile: "/tmp/output.png",
323+
}
324+
325+
// Test the validation logic directly
326+
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
327+
assert.False(t, shouldFail, "Validation should not trigger for supported model")
328+
})
329+
330+
t.Run("Unsupported model without image file should not trigger validation", func(t *testing.T) {
331+
opts := &common.ChatOptions{
332+
Model: "o1-mini",
333+
ImageFile: "", // No image file
334+
}
335+
336+
// Test the validation logic directly
337+
shouldFail := opts.ImageFile != "" && !supportsImageGeneration(opts.Model)
338+
assert.False(t, shouldFail, "Validation should not trigger when no image file is specified")
339+
})
340+
}

0 commit comments

Comments
 (0)