Skip to content

Commit e198347

Browse files
mudlerdave-gray101
andauthored
feat(openai): add json_schema format type and strict mode (#3193)
* feat(openai): add json_schema and strict mode Signed-off-by: Ettore Di Giacinto <[email protected]> * handle err vs _ security scanners prefer if we put these branches in, and I tend to agree. Signed-off-by: Dave <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]> Signed-off-by: Dave <[email protected]> Co-authored-by: Dave <[email protected]>
1 parent 66cf38b commit e198347

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

core/http/endpoints/openai/chat.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
172172

173173
funcs := input.Functions
174174
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
175+
strictMode := false
176+
177+
for _, f := range input.Functions {
178+
if f.Strict {
179+
strictMode = true
180+
break
181+
}
182+
}
175183

176184
// Allow the user to set custom actions via config file
177185
// to be "embedded" in each model
@@ -187,10 +195,33 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
187195

188196
if config.ResponseFormatMap != nil {
189197
d := schema.ChatCompletionResponseFormat{}
190-
dat, _ := json.Marshal(config.ResponseFormatMap)
191-
_ = json.Unmarshal(dat, &d)
198+
dat, err := json.Marshal(config.ResponseFormatMap)
199+
if err != nil {
200+
return err
201+
}
202+
err = json.Unmarshal(dat, &d)
203+
if err != nil {
204+
return err
205+
}
192206
if d.Type == "json_object" {
193207
input.Grammar = functions.JSONBNF
208+
} else if d.Type == "json_schema" {
209+
d := schema.JsonSchemaRequest{}
210+
dat, err := json.Marshal(config.ResponseFormatMap)
211+
if err != nil {
212+
return err
213+
}
214+
err = json.Unmarshal(dat, &d)
215+
if err != nil {
216+
return err
217+
}
218+
fs := &functions.JSONFunctionStructure{
219+
AnyOf: []functions.Item{d.JsonSchema.Schema},
220+
}
221+
g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
222+
if err == nil {
223+
input.Grammar = g
224+
}
194225
}
195226
}
196227

@@ -201,7 +232,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
201232
}
202233

203234
switch {
204-
case !config.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn:
235+
case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
205236
noActionGrammar := functions.Function{
206237
Name: noActionName,
207238
Description: noActionDescription,

core/schema/openai.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,17 @@ type ChatCompletionResponseFormat struct {
139139
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
140140
}
141141

142+
type JsonSchemaRequest struct {
143+
Type string `json:"type"`
144+
JsonSchema JsonSchema `json:"json_schema"`
145+
}
146+
147+
type JsonSchema struct {
148+
Name string `json:"name"`
149+
Strict bool `json:"strict"`
150+
Schema functions.Item `json:"schema"`
151+
}
152+
142153
type OpenAIRequest struct {
143154
PredictionOptions
144155

pkg/functions/functions.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const (
1414
type Function struct {
1515
Name string `json:"name"`
1616
Description string `json:"description"`
17+
Strict bool `json:"strict"`
1718
Parameters map[string]interface{} `json:"parameters"`
1819
}
1920
type Functions []Function

0 commit comments

Comments
 (0)