
Commit 33f9ee0

fix(gallery): automatically install model from name (#5757)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent c546774 · commit 33f9ee0

File tree

5 files changed: +36 −17

core/backend/llm.go

Lines changed: 9 additions & 4 deletions
```diff
@@ -4,8 +4,8 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"os"
 	"regexp"
+	"slices"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -14,6 +14,7 @@ import (
 
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/core/services"
 
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -34,15 +35,19 @@ type TokenUsage struct {
 	TimingTokenGeneration float64
 }
 
-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, cl *config.BackendConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model
 
 	// Check if the modelFile exists, if it doesn't try to load it from the gallery
 	if o.AutoloadGalleries { // experimental
-		if _, err := os.Stat(modelFile); os.IsNotExist(err) {
+		modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
+		if err != nil {
+			return nil, err
+		}
+		if !slices.Contains(modelNames, c.Name) {
 			utils.ResetDownloadTimers()
 			// if we failed to load the model, we try to download it
-			err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, modelFile, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+			err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, c.Name, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
 			if err != nil {
 				log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
 				//return nil, err
```
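
In short: the gallery autoload path no longer stats the model file on disk and installs by file name; it asks the config loader which models are installed and installs by the configured model name. A minimal sketch of that flow, lifted out of ModelInference for illustration (the helper name `ensureModelInstalled` is hypothetical, the `pkg/model` and `pkg/utils` import paths are assumed from the repository layout, and the real code only logs the install error rather than returning it):

```go
package backend

import (
	"slices"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
)

// ensureModelInstalled mirrors the new autoload logic: list the models the
// config loader knows about and, if the requested name is missing, install
// it from the configured galleries by name (c.Name), not by model file path
// as before.
func ensureModelInstalled(cl *config.BackendConfigLoader, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig) error {
	modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
	if err != nil {
		return err
	}
	if slices.Contains(modelNames, c.Name) {
		return nil // already installed, nothing to do
	}
	utils.ResetDownloadTimers()
	// Note: ModelInference itself only logs this error and continues.
	return gallery.InstallModelFromGallery(
		o.Galleries, o.BackendGalleries, c.Name,
		loader.ModelPath, o.BackendsPath,
		gallery.GalleryModel{}, utils.DisplayDownloadFunction,
		o.EnforcePredownloadScans, o.AutoloadBackendGalleries,
	)
}
```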

core/http/endpoints/openai/chat.go

Lines changed: 20 additions & 8 deletions
```diff
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		}
 		responses <- initialMessage
 
-		ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -68,7 +68,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
@@ -92,7 +92,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		}
 		responses <- initialMessage
 
-		result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt)
+		result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 		if err != nil {
 			log.Error().Err(err).Msg("error handling question")
 			return
@@ -383,7 +383,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 
 	// no streaming mode
 	default:
-		result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+
+		tokenCallback := func(s string, c *[]schema.Choice) {
 			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
 				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
@@ -403,7 +404,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 
 			switch {
 			case noActionsToRun:
-				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+				result, err := handleQuestion(config, cl, input, ml, startupOptions, results, s, predInput)
 				if err != nil {
 					log.Error().Err(err).Msg("error handling question")
 					return
@@ -458,7 +459,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 			}
 		}
 
-		}, nil)
+		}
+
+		result, tokenUsage, err := ComputeChoices(
+			input,
+			predInput,
+			config,
+			cl,
+			startupOptions,
+			ml,
+			tokenCallback,
+			nil,
+		)
 		if err != nil {
 			return err
 		}
@@ -489,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 }
 
-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
 
 	if len(funcResults) == 0 && result != "" {
 		log.Debug().Msgf("nothing function results but we had a message from the LLM")
@@ -538,7 +550,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 		audios = append(audios, m.StringAudios...)
 	}
 
-	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
 	if err != nil {
 		log.Error().Err(err).Msg("model inference failed")
 		return "", err
```

core/http/endpoints/openai/completion.go

Lines changed: 4 additions & 3 deletions
```diff
@@ -31,7 +31,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 	created := int(time.Now().Unix())
 
 	process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
-		ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -58,7 +58,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 
 			responses <- resp
 			return true
-		})
+		}
+		ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
 		close(responses)
 	}
 
@@ -168,7 +169,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 		}
 
 		r, tokenUsage, err := ComputeChoices(
-			input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
+			input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
 				*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
 			}, nil)
 		if err != nil {
```

core/http/endpoints/openai/edit.go

Lines changed: 1 addition & 1 deletion
```diff
@@ -56,7 +56,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		log.Debug().Msgf("Template found, input modified to: %s", i)
 	}
 
-	r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
+	r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
 		*c = append(*c, schema.Choice{Text: s})
 	}, nil)
 	if err != nil {
```

core/http/endpoints/openai/inference.go

Lines changed: 2 additions & 1 deletion
```diff
@@ -12,6 +12,7 @@ func ComputeChoices(
 	req *schema.OpenAIRequest,
 	predInput string,
 	config *config.BackendConfig,
+	bcl *config.BackendConfigLoader,
 	o *config.ApplicationConfig,
 	loader *model.ModelLoader,
 	cb func(string, *[]schema.Choice),
@@ -37,7 +38,7 @@ func ComputeChoices(
 	}
 
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback)
 	if err != nil {
 		return result, backend.TokenUsage{}, err
 	}
```
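
One detail for anyone updating call sites: the new parameter is positional and sits between the backend config and the application config in both ComputeChoices (`bcl`) and backend.ModelInference (`cl`). The updated call shape, mirrored from the edit.go hunk above (a fragment, not standalone code):

```go
// cl (*config.BackendConfigLoader) now goes between config and appConfig.
r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml,
	func(s string, c *[]schema.Choice) {
		*c = append(*c, schema.Choice{Text: s})
	}, nil)
```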
