Skip to content

Commit 36524cd

Browse files
authored
Merge pull request #1573 from ksylvan/0704-image-dynamic-formats
Add Image File Validation and Dynamic Format Support
2 parents 1eac026 + e59156a commit 36524cd

File tree

10 files changed

+262
-10
lines changed

10 files changed

+262
-10
lines changed

cli/cli.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,11 @@ func Cli(version string) (err error) {
270270
if chatReq.Language == "" {
271271
chatReq.Language = registry.Language.DefaultLanguage.Value
272272
}
273-
if session, err = chatter.Send(chatReq, currentFlags.BuildChatOptions()); err != nil {
273+
var chatOptions *common.ChatOptions
274+
if chatOptions, err = currentFlags.BuildChatOptions(); err != nil {
275+
return
276+
}
277+
if session, err = chatter.Send(chatReq, chatOptions); err != nil {
274278
return
275279
}
276280

cli/flags.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"os"
9+
"path/filepath"
910
"reflect"
1011
"strconv"
1112
"strings"
@@ -14,7 +15,7 @@ import (
1415
"github.com/danielmiessler/fabric/common"
1516
"github.com/jessevdk/go-flags"
1617
"golang.org/x/text/language"
17-
"gopkg.in/yaml.v2"
18+
"gopkg.in/yaml.v3"
1819
)
1920

2021
// Flags create flags struct. the users flags go into this, this will be passed to the chat struct in cli
@@ -257,7 +258,36 @@ func readStdin() (ret string, err error) {
257258
return
258259
}
259260

260-
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
261+
// validateImageFile validates the image file path and extension
262+
func validateImageFile(imagePath string) error {
263+
if imagePath == "" {
264+
return nil // No validation needed if no image file specified
265+
}
266+
267+
// Check if file already exists
268+
if _, err := os.Stat(imagePath); err == nil {
269+
return fmt.Errorf("image file already exists: %s", imagePath)
270+
}
271+
272+
// Check file extension
273+
ext := strings.ToLower(filepath.Ext(imagePath))
274+
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
275+
276+
for _, validExt := range validExtensions {
277+
if ext == validExt {
278+
return nil // Valid extension found
279+
}
280+
}
281+
282+
return fmt.Errorf("invalid image file extension '%s'. Supported formats: .png, .jpeg, .jpg, .webp", ext)
283+
}
284+
285+
func (o *Flags) BuildChatOptions() (ret *common.ChatOptions, err error) {
286+
// Validate image file if specified
287+
if err = validateImageFile(o.ImageFile); err != nil {
288+
return nil, err
289+
}
290+
261291
ret = &common.ChatOptions{
262292
Temperature: o.Temperature,
263293
TopP: o.TopP,

cli/flags_test.go

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"io"
66
"os"
7+
"path/filepath"
78
"strings"
89
"testing"
910

@@ -64,7 +65,8 @@ func TestBuildChatOptions(t *testing.T) {
6465
Raw: false,
6566
Seed: 1,
6667
}
67-
options := flags.BuildChatOptions()
68+
options, err := flags.BuildChatOptions()
69+
assert.NoError(t, err)
6870
assert.Equal(t, expectedOptions, options)
6971
}
7072

@@ -84,7 +86,8 @@ func TestBuildChatOptionsDefaultSeed(t *testing.T) {
8486
Raw: false,
8587
Seed: 0,
8688
}
87-
options := flags.BuildChatOptions()
89+
options, err := flags.BuildChatOptions()
90+
assert.NoError(t, err)
8891
assert.Equal(t, expectedOptions, options)
8992
}
9093

@@ -164,3 +167,91 @@ model: 123 # should be string
164167
assert.Error(t, err)
165168
})
166169
}
170+
171+
func TestValidateImageFile(t *testing.T) {
172+
t.Run("Empty path should be valid", func(t *testing.T) {
173+
err := validateImageFile("")
174+
assert.NoError(t, err)
175+
})
176+
177+
t.Run("Valid extensions should pass", func(t *testing.T) {
178+
validExtensions := []string{".png", ".jpeg", ".jpg", ".webp"}
179+
for _, ext := range validExtensions {
180+
filename := "/tmp/test" + ext
181+
err := validateImageFile(filename)
182+
assert.NoError(t, err, "Extension %s should be valid", ext)
183+
}
184+
})
185+
186+
t.Run("Invalid extensions should fail", func(t *testing.T) {
187+
invalidExtensions := []string{".gif", ".bmp", ".tiff", ".svg", ".txt", ""}
188+
for _, ext := range invalidExtensions {
189+
filename := "/tmp/test" + ext
190+
err := validateImageFile(filename)
191+
assert.Error(t, err, "Extension %s should be invalid", ext)
192+
assert.Contains(t, err.Error(), "invalid image file extension")
193+
}
194+
})
195+
196+
t.Run("Existing file should fail", func(t *testing.T) {
197+
// Create a temporary file
198+
tempFile, err := os.CreateTemp("", "test*.png")
199+
assert.NoError(t, err)
200+
defer os.Remove(tempFile.Name())
201+
tempFile.Close()
202+
203+
// Validation should fail because file exists
204+
err = validateImageFile(tempFile.Name())
205+
assert.Error(t, err)
206+
assert.Contains(t, err.Error(), "image file already exists")
207+
})
208+
209+
t.Run("Non-existing file with valid extension should pass", func(t *testing.T) {
210+
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.png")
211+
// Make sure the file doesn't exist
212+
os.Remove(nonExistentFile)
213+
214+
err := validateImageFile(nonExistentFile)
215+
assert.NoError(t, err)
216+
})
217+
}
218+
219+
func TestBuildChatOptionsWithImageFileValidation(t *testing.T) {
220+
t.Run("Valid image file should pass", func(t *testing.T) {
221+
flags := &Flags{
222+
ImageFile: "/tmp/output.png",
223+
}
224+
225+
options, err := flags.BuildChatOptions()
226+
assert.NoError(t, err)
227+
assert.Equal(t, "/tmp/output.png", options.ImageFile)
228+
})
229+
230+
t.Run("Invalid extension should fail", func(t *testing.T) {
231+
flags := &Flags{
232+
ImageFile: "/tmp/output.gif",
233+
}
234+
235+
options, err := flags.BuildChatOptions()
236+
assert.Error(t, err)
237+
assert.Nil(t, options)
238+
assert.Contains(t, err.Error(), "invalid image file extension")
239+
})
240+
241+
t.Run("Existing file should fail", func(t *testing.T) {
242+
// Create a temporary file
243+
tempFile, err := os.CreateTemp("", "existing*.png")
244+
assert.NoError(t, err)
245+
defer os.Remove(tempFile.Name())
246+
tempFile.Close()
247+
248+
flags := &Flags{
249+
ImageFile: tempFile.Name(),
250+
}
251+
252+
options, err := flags.BuildChatOptions()
253+
assert.Error(t, err)
254+
assert.Nil(t, options)
255+
assert.Contains(t, err.Error(), "image file already exists")
256+
})
257+
}

completions/_fabric

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ _fabric() {
9898
'(--version)--version[Print current version]' \
9999
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
100100
'(--search-location)--search-location[Set location for web search results]:location:' \
101-
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.jpg *.jpeg *.gif *.bmp"' \
101+
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
102102
'(--listextensions)--listextensions[List all registered extensions]' \
103103
'(--addextension)--addextension[Register a new extension from config file path]:config file:_files -g "*.yaml *.yml"' \
104104
'(--rmextension)--rmextension[Remove a registered extension by name]:extension:_fabric_extensions' \

completions/fabric.fish

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ complete -c fabric -l address -d "The address to bind the REST API (default: :80
6161
complete -c fabric -l api-key -d "API key used to secure server routes"
6262
complete -c fabric -l config -d "Path to YAML config file" -r -a "*.yaml *.yml"
6363
complete -c fabric -l search-location -d "Set location for web search results (e.g., 'America/Los_Angeles')"
64-
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.jpg *.jpeg *.gif *.bmp"
64+
complete -c fabric -l image-file -d "Save generated image to specified file path (e.g., 'output.png')" -r -a "*.png *.webp *.jpeg *.jpg"
6565
complete -c fabric -l addextension -d "Register a new extension from config file path" -r -a "*.yaml *.yml"
6666
complete -c fabric -l rmextension -d "Remove a registered extension by name" -a "(__fabric_get_extensions)"
6767
complete -c fabric -l strategy -d "Choose a strategy from the available strategies" -a "(__fabric_get_strategies)"

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ require (
2727
github.com/stretchr/testify v1.10.0
2828
golang.org/x/text v0.26.0
2929
google.golang.org/api v0.236.0
30-
gopkg.in/yaml.v2 v2.4.0
3130
gopkg.in/yaml.v3 v3.0.1
3231
)
3332

go.sum

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
354354
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
355355
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
356356
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
357-
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
358357
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
359358
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
360359
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

mars-colony.png

1.79 MB
Loading

plugins/ai/openai/openai_image.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"os"
1010
"path/filepath"
11+
"strings"
1112

1213
"github.com/danielmiessler/fabric/common"
1314
"github.com/openai/openai-go/responses"
@@ -17,15 +18,37 @@ import (
1718
const ImageGenerationResponseType = "image_generation_call"
1819
const ImageGenerationToolType = "image_generation"
1920

21+
// getOutputFormatFromExtension determines the API output format based on file extension
22+
func getOutputFormatFromExtension(imagePath string) string {
23+
if imagePath == "" {
24+
return "png" // Default format
25+
}
26+
27+
ext := strings.ToLower(filepath.Ext(imagePath))
28+
switch ext {
29+
case ".png":
30+
return "png"
31+
case ".webp":
32+
return "webp"
33+
case ".jpg":
34+
return "jpeg"
35+
case ".jpeg":
36+
return "jpeg"
37+
default:
38+
return "png" // Default fallback
39+
}
40+
}
41+
2042
// addImageGenerationTool adds the image generation tool to the request if needed
2143
func (o *Client) addImageGenerationTool(opts *common.ChatOptions, tools []responses.ToolUnionParam) []responses.ToolUnionParam {
2244
// Check if the request seems to be asking for image generation
2345
if o.shouldUseImageGeneration(opts) {
46+
outputFormat := getOutputFormatFromExtension(opts.ImageFile)
2447
imageGenTool := responses.ToolUnionParam{
2548
OfImageGeneration: &responses.ToolImageGenerationParam{
2649
Type: ImageGenerationToolType,
2750
Model: "gpt-image-1",
28-
OutputFormat: "png",
51+
OutputFormat: outputFormat,
2952
Quality: "auto",
3053
Size: "auto",
3154
},

plugins/ai/openai/openai_image_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,109 @@ func TestBuildResponseParams_WithBothSearchAndImage(t *testing.T) {
112112
assert.True(t, hasSearchTool, "Should have web search tool")
113113
assert.True(t, hasImageTool, "Should have image generation tool")
114114
}
115+
116+
func TestGetOutputFormatFromExtension(t *testing.T) {
117+
tests := []struct {
118+
name string
119+
imagePath string
120+
expectedFormat string
121+
}{
122+
{
123+
name: "PNG extension",
124+
imagePath: "/tmp/output.png",
125+
expectedFormat: "png",
126+
},
127+
{
128+
name: "WEBP extension",
129+
imagePath: "/tmp/output.webp",
130+
expectedFormat: "webp",
131+
},
132+
{
133+
name: "JPG extension",
134+
imagePath: "/tmp/output.jpg",
135+
expectedFormat: "jpeg",
136+
},
137+
{
138+
name: "JPEG extension",
139+
imagePath: "/tmp/output.jpeg",
140+
expectedFormat: "jpeg",
141+
},
142+
{
143+
name: "Uppercase PNG extension",
144+
imagePath: "/tmp/output.PNG",
145+
expectedFormat: "png",
146+
},
147+
{
148+
name: "Mixed case JPEG extension",
149+
imagePath: "/tmp/output.JpEg",
150+
expectedFormat: "jpeg",
151+
},
152+
{
153+
name: "Empty path",
154+
imagePath: "",
155+
expectedFormat: "png",
156+
},
157+
{
158+
name: "No extension",
159+
imagePath: "/tmp/output",
160+
expectedFormat: "png",
161+
},
162+
{
163+
name: "Unsupported extension",
164+
imagePath: "/tmp/output.gif",
165+
expectedFormat: "png",
166+
},
167+
}
168+
169+
for _, tt := range tests {
170+
t.Run(tt.name, func(t *testing.T) {
171+
result := getOutputFormatFromExtension(tt.imagePath)
172+
assert.Equal(t, tt.expectedFormat, result)
173+
})
174+
}
175+
}
176+
177+
func TestAddImageGenerationToolWithDynamicFormat(t *testing.T) {
178+
client := NewClient()
179+
180+
tests := []struct {
181+
name string
182+
imageFile string
183+
expectedFormat string
184+
}{
185+
{
186+
name: "PNG file",
187+
imageFile: "/tmp/output.png",
188+
expectedFormat: "png",
189+
},
190+
{
191+
name: "WEBP file",
192+
imageFile: "/tmp/output.webp",
193+
expectedFormat: "webp",
194+
},
195+
{
196+
name: "JPG file",
197+
imageFile: "/tmp/output.jpg",
198+
expectedFormat: "jpeg",
199+
},
200+
{
201+
name: "JPEG file",
202+
imageFile: "/tmp/output.jpeg",
203+
expectedFormat: "jpeg",
204+
},
205+
}
206+
207+
for _, tt := range tests {
208+
t.Run(tt.name, func(t *testing.T) {
209+
opts := &common.ChatOptions{
210+
ImageFile: tt.imageFile,
211+
}
212+
213+
tools := client.addImageGenerationTool(opts, []responses.ToolUnionParam{})
214+
215+
assert.Len(t, tools, 1, "Should have one tool")
216+
assert.NotNil(t, tools[0].OfImageGeneration, "Should be image generation tool")
217+
assert.Equal(t, tt.expectedFormat, tools[0].OfImageGeneration.OutputFormat, "Output format should match file extension")
218+
})
219+
}
220+
}

0 commit comments

Comments
 (0)