Skip to content

Commit 7f06954

Browse files
authored
fix(model-loading): keep track of open GRPC Clients (#3377)
Due to a previous refactor, we tied the client constructor to the model address; however, that address was just a string that we used to build a new client each time. With this change, the loader returns a *Model, which carries a constructor for the client and stores the client on the first connection. Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 771a052 commit 7f06954

File tree

10 files changed

+176
-171
lines changed

10 files changed

+176
-171
lines changed

core/services/backend_monitor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status
107107
return nil, err
108108
}
109109
modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
110-
if modelAddr == "" {
110+
if modelAddr == nil {
111111
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
112112
}
113113

pkg/grpc/backend.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool)
1818
if bc, ok := embeds[address]; ok {
1919
return bc
2020
}
21-
return NewGrpcClient(address, parallel, wd, enableWatchDog)
21+
return buildClient(address, parallel, wd, enableWatchDog)
2222
}
2323

24-
func NewGrpcClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
24+
func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) Backend {
2525
if !enableWatchDog {
2626
wd = nil
2727
}

pkg/grpc/client.go

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ func (c *Client) setBusy(v bool) {
3939
c.Unlock()
4040
}
4141

42+
func (c *Client) wdMark() {
43+
if c.wd != nil {
44+
c.wd.Mark(c.address)
45+
}
46+
}
47+
48+
func (c *Client) wdUnMark() {
49+
if c.wd != nil {
50+
c.wd.UnMark(c.address)
51+
}
52+
}
53+
4254
func (c *Client) HealthCheck(ctx context.Context) (bool, error) {
4355
if !c.parallel {
4456
c.opMutex.Lock()
@@ -76,10 +88,8 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
7688
}
7789
c.setBusy(true)
7890
defer c.setBusy(false)
79-
if c.wd != nil {
80-
c.wd.Mark(c.address)
81-
defer c.wd.UnMark(c.address)
82-
}
91+
c.wdMark()
92+
defer c.wdUnMark()
8393
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
8494
if err != nil {
8595
return nil, err
@@ -97,10 +107,8 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
97107
}
98108
c.setBusy(true)
99109
defer c.setBusy(false)
100-
if c.wd != nil {
101-
c.wd.Mark(c.address)
102-
defer c.wd.UnMark(c.address)
103-
}
110+
c.wdMark()
111+
defer c.wdUnMark()
104112
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
105113
if err != nil {
106114
return nil, err
@@ -118,10 +126,8 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
118126
}
119127
c.setBusy(true)
120128
defer c.setBusy(false)
121-
if c.wd != nil {
122-
c.wd.Mark(c.address)
123-
defer c.wd.UnMark(c.address)
124-
}
129+
c.wdMark()
130+
defer c.wdUnMark()
125131
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
126132
if err != nil {
127133
return nil, err
@@ -138,10 +144,8 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
138144
}
139145
c.setBusy(true)
140146
defer c.setBusy(false)
141-
if c.wd != nil {
142-
c.wd.Mark(c.address)
143-
defer c.wd.UnMark(c.address)
144-
}
147+
c.wdMark()
148+
defer c.wdUnMark()
145149
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
146150
if err != nil {
147151
return err
@@ -177,10 +181,8 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
177181
}
178182
c.setBusy(true)
179183
defer c.setBusy(false)
180-
if c.wd != nil {
181-
c.wd.Mark(c.address)
182-
defer c.wd.UnMark(c.address)
183-
}
184+
c.wdMark()
185+
defer c.wdUnMark()
184186
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
185187
if err != nil {
186188
return nil, err
@@ -197,10 +199,8 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
197199
}
198200
c.setBusy(true)
199201
defer c.setBusy(false)
200-
if c.wd != nil {
201-
c.wd.Mark(c.address)
202-
defer c.wd.UnMark(c.address)
203-
}
202+
c.wdMark()
203+
defer c.wdUnMark()
204204
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
205205
if err != nil {
206206
return nil, err
@@ -217,10 +217,8 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
217217
}
218218
c.setBusy(true)
219219
defer c.setBusy(false)
220-
if c.wd != nil {
221-
c.wd.Mark(c.address)
222-
defer c.wd.UnMark(c.address)
223-
}
220+
c.wdMark()
221+
defer c.wdUnMark()
224222
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
225223
if err != nil {
226224
return nil, err
@@ -237,10 +235,8 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
237235
}
238236
c.setBusy(true)
239237
defer c.setBusy(false)
240-
if c.wd != nil {
241-
c.wd.Mark(c.address)
242-
defer c.wd.UnMark(c.address)
243-
}
238+
c.wdMark()
239+
defer c.wdUnMark()
244240
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
245241
if err != nil {
246242
return nil, err
@@ -277,10 +273,8 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
277273
}
278274
c.setBusy(true)
279275
defer c.setBusy(false)
280-
if c.wd != nil {
281-
c.wd.Mark(c.address)
282-
defer c.wd.UnMark(c.address)
283-
}
276+
c.wdMark()
277+
defer c.wdUnMark()
284278
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
285279
if err != nil {
286280
return nil, err
@@ -319,6 +313,8 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
319313
}
320314
c.setBusy(true)
321315
defer c.setBusy(false)
316+
c.wdMark()
317+
defer c.wdUnMark()
322318
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
323319
if err != nil {
324320
return nil, err
@@ -333,6 +329,8 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
333329
c.opMutex.Lock()
334330
defer c.opMutex.Unlock()
335331
}
332+
c.wdMark()
333+
defer c.wdUnMark()
336334
c.setBusy(true)
337335
defer c.setBusy(false)
338336
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -351,6 +349,8 @@ func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ..
351349
}
352350
c.setBusy(true)
353351
defer c.setBusy(false)
352+
c.wdMark()
353+
defer c.wdUnMark()
354354
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
355355
if err != nil {
356356
return nil, err
@@ -367,6 +367,8 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
367367
}
368368
c.setBusy(true)
369369
defer c.setBusy(false)
370+
c.wdMark()
371+
defer c.wdUnMark()
370372
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
371373
if err != nil {
372374
return nil, err
@@ -383,6 +385,8 @@ func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.
383385
}
384386
c.setBusy(true)
385387
defer c.setBusy(false)
388+
c.wdMark()
389+
defer c.wdUnMark()
386390
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
387391
if err != nil {
388392
return nil, err

pkg/model/initializers.go

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ ENTRY:
8080
if e.IsDir() {
8181
continue
8282
}
83+
if strings.HasSuffix(e.Name(), ".log") {
84+
continue
85+
}
8386

8487
// Skip the llama.cpp variants if we are autoDetecting
8588
// But we always load the fallback variant if it exists
@@ -265,12 +268,12 @@ func selectGRPCProcess(backend, assetDir string, f16 bool) string {
265268

266269
// starts the grpcModelProcess for the backend, and returns a grpc client
267270
// It also loads the model
268-
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
269-
return func(modelName, modelFile string) (ModelAddress, error) {
271+
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*Model, error) {
272+
return func(modelName, modelFile string) (*Model, error) {
270273

271274
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
272275

273-
var client ModelAddress
276+
var client *Model
274277

275278
getFreeAddress := func() (string, error) {
276279
port, err := freeport.GetFreePort()
@@ -298,26 +301,26 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
298301
log.Debug().Msgf("external backend is file: %+v", fi)
299302
serverAddress, err := getFreeAddress()
300303
if err != nil {
301-
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
304+
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
302305
}
303306
// Make sure the process is executable
304307
if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
305308
log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
306-
return "", err
309+
return nil, err
307310
}
308311

309312
log.Debug().Msgf("GRPC Service Started")
310313

311-
client = ModelAddress(serverAddress)
314+
client = NewModel(serverAddress)
312315
} else {
313316
log.Debug().Msg("external backend is uri")
314317
// address
315-
client = ModelAddress(uri)
318+
client = NewModel(uri)
316319
}
317320
} else {
318321
grpcProcess := backendPath(o.assetDir, backend)
319322
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
320-
return "", fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
323+
return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
321324
}
322325

323326
if autoDetect {
@@ -329,12 +332,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
329332

330333
// Check if the file exists
331334
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
332-
return "", fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
335+
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
333336
}
334337

335338
serverAddress, err := getFreeAddress()
336339
if err != nil {
337-
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
340+
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
338341
}
339342

340343
args := []string{}
@@ -344,12 +347,12 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
344347

345348
// Make sure the process is executable in any circumstance
346349
if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil {
347-
return "", err
350+
return nil, err
348351
}
349352

350353
log.Debug().Msgf("GRPC Service Started")
351354

352-
client = ModelAddress(serverAddress)
355+
client = NewModel(serverAddress)
353356
}
354357

355358
// Wait for the service to start up
@@ -369,7 +372,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
369372

370373
if !ready {
371374
log.Debug().Msgf("GRPC Service NOT ready")
372-
return "", fmt.Errorf("grpc service not ready")
375+
return nil, fmt.Errorf("grpc service not ready")
373376
}
374377

375378
options := *o.gRPCOptions
@@ -380,27 +383,16 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
380383

381384
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
382385
if err != nil {
383-
return "", fmt.Errorf("could not load model: %w", err)
386+
return nil, fmt.Errorf("could not load model: %w", err)
384387
}
385388
if !res.Success {
386-
return "", fmt.Errorf("could not load model (no success): %s", res.Message)
389+
return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
387390
}
388391

389392
return client, nil
390393
}
391394
}
392395

393-
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (grpc.Backend, error) {
394-
if parallel {
395-
return addr.GRPC(parallel, ml.wd), nil
396-
}
397-
398-
if _, ok := ml.grpcClients[string(addr)]; !ok {
399-
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
400-
}
401-
return ml.grpcClients[string(addr)], nil
402-
}
403-
404396
func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) {
405397
o := NewOptions(opts...)
406398

@@ -425,7 +417,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
425417
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
426418
return nil, err
427419
}
428-
429420
}
430421

431422
var backendToConsume string
@@ -438,26 +429,28 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
438429
backendToConsume = backend
439430
}
440431

441-
addr, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
432+
model, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
442433
if err != nil {
443434
return nil, err
444435
}
445436

446-
return ml.resolveAddress(addr, o.parallelRequests)
437+
return model.GRPC(o.parallelRequests, ml.wd), nil
447438
}
448439

449440
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
450441
o := NewOptions(opts...)
451442

452443
ml.mu.Lock()
444+
453445
// Return earlier if we have a model already loaded
454446
// (avoid looping through all the backends)
455-
if m := ml.CheckIsLoaded(o.model); m != "" {
447+
if m := ml.CheckIsLoaded(o.model); m != nil {
456448
log.Debug().Msgf("Model '%s' already loaded", o.model)
457449
ml.mu.Unlock()
458450

459-
return ml.resolveAddress(m, o.parallelRequests)
451+
return m.GRPC(o.parallelRequests, ml.wd), nil
460452
}
453+
461454
// If we can have only one backend active, kill all the others (except external backends)
462455
if o.singleActiveBackend {
463456
log.Debug().Msgf("Stopping all backends except '%s'", o.model)

0 commit comments

Comments
 (0)