@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 	responses <- initialMessage

-	ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+	ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 		usage := schema.OpenAIUsage{
 			PromptTokens:     tokenUsage.Prompt,
 			CompletionTokens: tokenUsage.Completion,
@@ -68,7 +68,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
@@ -92,7 +92,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		}
 		responses <- initialMessage

-		result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt)
+		result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 		if err != nil {
 			log.Error().Err(err).Msg("error handling question")
 			return
@@ -383,7 +383,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat

 	// no streaming mode
 	default:
-		result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+
+		tokenCallback := func(s string, c *[]schema.Choice) {
 			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
 				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
@@ -403,7 +404,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat

 			switch {
 			case noActionsToRun:
-				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+				result, err := handleQuestion(config, cl, input, ml, startupOptions, results, s, predInput)
 				if err != nil {
 					log.Error().Err(err).Msg("error handling question")
 					return
@@ -458,7 +459,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 				}
 			}

-		}, nil)
+		}
+
+		result, tokenUsage, err := ComputeChoices(
+			input,
+			predInput,
+			config,
+			cl,
+			startupOptions,
+			ml,
+			tokenCallback,
+			nil,
+		)
 		if err != nil {
 			return err
 		}
@@ -489,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 }

-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {

 	if len(funcResults) == 0 && result != "" {
 		log.Debug().Msgf("nothing function results but we had a message from the LLM")
@@ -538,7 +550,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 		audios = append(audios, m.StringAudios...)
 	}

-	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
 	if err != nil {
 		log.Error().Err(err).Msg("model inference failed")
 		return "", err
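The other change in the non-streaming branch is stylistic: the inline closure previously passed to `ComputeChoices` is lifted into a named `tokenCallback` variable, and the call is reflowed to one argument per line. A small sketch of why that helps; `computeChoices` below is a made-up stand-in with a similar shape, not the real function:

```go
package main

import "fmt"

// Stand-in, invented for illustration: several positional arguments
// plus a callback, roughly the shape of ComputeChoices.
func computeChoices(input, prompt string, onToken func(s string, out *[]string), after func()) ([]string, error) {
	var out []string
	onToken(prompt, &out)
	if after != nil {
		after()
	}
	return out, nil
}

func main() {
	// Naming the closure first keeps the call site short, and one
	// argument per line means a new parameter (like the loader added
	// in this PR) shows up as a one-line diff instead of a reflowed
	// 50-column call.
	tokenCallback := func(s string, out *[]string) {
		*out = append(*out, s)
	}

	result, err := computeChoices(
		"input",
		"hello",
		tokenCallback,
		nil,
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(result)
}
```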