Skip to content

Commit 81ae92f

Browse files
dave-gray101mudler
andauthored
feat: elevenlabs sound-generation api (#3355)
* initial version of elevenlabs compatible soundgeneration api and cli command Signed-off-by: Dave Lee <[email protected]> * minor cleanup Signed-off-by: Dave Lee <[email protected]> * restore TTS, add test Signed-off-by: Dave Lee <[email protected]> * remove stray s Signed-off-by: Dave Lee <[email protected]> * fix Signed-off-by: Dave Lee <[email protected]> --------- Signed-off-by: Dave Lee <[email protected]> Signed-off-by: Ettore Di Giacinto <[email protected]> Co-authored-by: Ettore Di Giacinto <[email protected]>
1 parent 84d6e5a commit 81ae92f

File tree

20 files changed

+450
-37
lines changed

20 files changed

+450
-37
lines changed

backend/backend.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ service Backend {
1616
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
1717
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
1818
rpc TTS(TTSRequest) returns (Result) {}
19+
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
1920
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
2021
rpc Status(HealthMessage) returns (StatusResponse) {}
2122

@@ -270,6 +271,17 @@ message TTSRequest {
270271
optional string language = 5;
271272
}
272273

274+
message SoundGenerationRequest {
275+
string text = 1;
276+
string model = 2;
277+
string dst = 3;
278+
optional float duration = 4;
279+
optional float temperature = 5;
280+
optional bool sample = 6;
281+
optional string src = 7;
282+
optional int32 src_divisor = 8;
283+
}
284+
273285
message TokenizationResponse {
274286
int32 length = 1;
275287
repeated int32 tokens = 2;

backend/python/transformers-musicgen/backend.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import grpc
1717

18-
from scipy.io.wavfile import write as write_wav
18+
from scipy.io import wavfile
1919
from transformers import AutoProcessor, MusicgenForConditionalGeneration
2020

2121
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -63,6 +63,61 @@ def LoadModel(self, request, context):
6363

6464
return backend_pb2.Result(message="Model loaded successfully", success=True)
6565

66+
def SoundGeneration(self, request, context):
67+
model_name = request.model
68+
if model_name == "":
69+
return backend_pb2.Result(success=False, message="request.model is required")
70+
try:
71+
self.processor = AutoProcessor.from_pretrained(model_name)
72+
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
73+
inputs = None
74+
if request.text == "":
75+
inputs = self.model.get_unconditional_inputs(num_samples=1)
76+
elif request.HasField('src'):
77+
# TODO SECURITY CODE GOES HERE LOL
78+
# WHO KNOWS IF THIS WORKS???
79+
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
80+
81+
if request.HasField('src_divisor'):
82+
wsamples = wsamples[: len(wsamples) // request.src_divisor]
83+
84+
inputs = self.processor(
85+
audio=wsamples,
86+
sampling_rate=sample_rate,
87+
text=[request.text],
88+
padding=True,
89+
return_tensors="pt",
90+
)
91+
else:
92+
inputs = self.processor(
93+
text=[request.text],
94+
padding=True,
95+
return_tensors="pt",
96+
)
97+
98+
tokens = 256
99+
if request.HasField('duration'):
100+
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
101+
guidance = 3.0
102+
if request.HasField('temperature'):
103+
guidance = request.temperature
104+
dosample = True
105+
if request.HasField('sample'):
106+
dosample = request.sample
107+
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
108+
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
109+
sampling_rate = self.model.config.audio_encoder.sampling_rate
110+
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
111+
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
112+
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
113+
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
114+
print(request, file=sys.stderr)
115+
except Exception as err:
116+
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
117+
return backend_pb2.Result(success=True)
118+
119+
120+
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
66121
def TTS(self, request, context):
67122
model_name = request.model
68123
if model_name == "":
@@ -75,8 +130,7 @@ def TTS(self, request, context):
75130
padding=True,
76131
return_tensors="pt",
77132
)
78-
tokens = 256
79-
# TODO get tokens from request?
133+
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
80134
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
81135
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
82136
sampling_rate = self.model.config.audio_encoder.sampling_rate

backend/python/transformers-musicgen/test.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_load_model(self):
6363

6464
def test_tts(self):
6565
"""
66-
This method tests if the embeddings are generated successfully
66+
This method tests if TTS is generated successfully
6767
"""
6868
try:
6969
self.setUp()
@@ -77,5 +77,24 @@ def test_tts(self):
7777
except Exception as err:
7878
print(err)
7979
self.fail("TTS service failed")
80+
finally:
81+
self.tearDown()
82+
83+
def test_sound_generation(self):
84+
"""
85+
This method tests if SoundGeneration is generated successfully
86+
"""
87+
try:
88+
self.setUp()
89+
with grpc.insecure_channel("localhost:50051") as channel:
90+
stub = backend_pb2_grpc.BackendStub(channel)
91+
response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
92+
self.assertTrue(response.success)
93+
sg_request = backend_pb2.SoundGenerationRequest(text="80s TV news production music hit for tonight's biggest story")
94+
sg_response = stub.SoundGeneration(sg_request)
95+
self.assertIsNotNone(sg_response)
96+
except Exception as err:
97+
print(err)
98+
self.fail("SoundGeneration service failed")
8099
finally:
81100
self.tearDown()

core/backend/llm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
8787
case string:
8888
protoMessages[i].Content = ct
8989
default:
90-
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
90+
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
9191
}
9292
}
9393
}

core/backend/soundgeneration.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package backend
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
9+
"github.com/mudler/LocalAI/core/config"
10+
"github.com/mudler/LocalAI/pkg/grpc/proto"
11+
"github.com/mudler/LocalAI/pkg/model"
12+
"github.com/mudler/LocalAI/pkg/utils"
13+
)
14+
15+
func SoundGeneration(
16+
backend string,
17+
modelFile string,
18+
text string,
19+
duration *float32,
20+
temperature *float32,
21+
doSample *bool,
22+
sourceFile *string,
23+
sourceDivisor *int32,
24+
loader *model.ModelLoader,
25+
appConfig *config.ApplicationConfig,
26+
backendConfig config.BackendConfig,
27+
) (string, *proto.Result, error) {
28+
if backend == "" {
29+
return "", nil, fmt.Errorf("backend is a required parameter")
30+
}
31+
32+
grpcOpts := gRPCModelOpts(backendConfig)
33+
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
34+
model.WithBackendString(backend),
35+
model.WithModel(modelFile),
36+
model.WithContext(appConfig.Context),
37+
model.WithAssetDir(appConfig.AssetsDestination),
38+
model.WithLoadGRPCLoadModelOpts(grpcOpts),
39+
})
40+
41+
soundGenModel, err := loader.BackendLoader(opts...)
42+
if err != nil {
43+
return "", nil, err
44+
}
45+
46+
if soundGenModel == nil {
47+
return "", nil, fmt.Errorf("could not load sound generation model")
48+
}
49+
50+
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
51+
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
52+
}
53+
54+
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
55+
filePath := filepath.Join(appConfig.AudioDir, fileName)
56+
57+
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
58+
Text: text,
59+
Model: modelFile,
60+
Dst: filePath,
61+
Sample: doSample,
62+
Duration: duration,
63+
Temperature: temperature,
64+
Src: sourceFile,
65+
SrcDivisor: sourceDivisor,
66+
})
67+
68+
// return RPC error if any
69+
if !res.Success {
70+
return "", nil, fmt.Errorf(res.Message)
71+
}
72+
73+
return filePath, res, err
74+
}

core/backend/tts.go

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,15 @@ import (
99
"github.com/mudler/LocalAI/core/config"
1010

1111
"github.com/mudler/LocalAI/pkg/grpc/proto"
12-
model "github.com/mudler/LocalAI/pkg/model"
12+
"github.com/mudler/LocalAI/pkg/model"
1313
"github.com/mudler/LocalAI/pkg/utils"
1414
)
1515

16-
func generateUniqueFileName(dir, baseName, ext string) string {
17-
counter := 1
18-
fileName := baseName + ext
19-
20-
for {
21-
filePath := filepath.Join(dir, fileName)
22-
_, err := os.Stat(filePath)
23-
if os.IsNotExist(err) {
24-
return fileName
25-
}
26-
27-
counter++
28-
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
29-
}
30-
}
31-
3216
func ModelTTS(
3317
backend,
3418
text,
3519
modelFile,
36-
voice ,
20+
voice,
3721
language string,
3822
loader *model.ModelLoader,
3923
appConfig *config.ApplicationConfig,
@@ -66,7 +50,7 @@ func ModelTTS(
6650
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
6751
}
6852

69-
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
53+
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
7054
filePath := filepath.Join(appConfig.AudioDir, fileName)
7155

7256
// If the model file is not empty, we pass it joined with the model path
@@ -88,10 +72,10 @@ func ModelTTS(
8872
}
8973

9074
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
91-
Text: text,
92-
Model: modelPath,
93-
Voice: voice,
94-
Dst: filePath,
75+
Text: text,
76+
Model: modelPath,
77+
Voice: voice,
78+
Dst: filePath,
9579
Language: &language,
9680
})
9781

core/cli/cli.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ import (
88
var CLI struct {
99
cliContext.Context `embed:""`
1010

11-
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
12-
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
13-
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
14-
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
15-
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
16-
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
17-
Util UtilCMD `cmd:"" help:"Utility commands"`
18-
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
11+
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
12+
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
13+
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
14+
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
15+
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
16+
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
17+
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
18+
Util UtilCMD `cmd:"" help:"Utility commands"`
19+
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
1920
}

0 commit comments

Comments
 (0)