Skip to content

Commit 74eaf02

Browse files
authored
feat(diffusers): support flux models (#3129)
* feat(diffusers): support flux models

  This adds support for FLUX models. For instance:
  https://huggingface.co/black-forest-labs/FLUX.1-dev

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* feat(diffusers): support FluxTransformer2DModel

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* Small fixups

  Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 7ba4a78 commit 74eaf02

File tree

6 files changed

+45
-7
lines changed

6 files changed

+45
-7
lines changed

backend/python/diffusers/backend.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
import grpc
1919

2020
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
21-
EulerAncestralDiscreteScheduler
21+
EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
2222
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
2323
from diffusers.pipelines.stable_diffusion import safety_checker
2424
from diffusers.utils import load_image, export_to_video
2525
from compel import Compel, ReturnedEmbeddingsType
26-
27-
from transformers import CLIPTextModel
26+
from optimum.quanto import freeze, qfloat8, quantize
27+
from transformers import CLIPTextModel, T5EncoderModel
2828
from safetensors.torch import load_file
2929

3030
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -163,6 +163,8 @@ def LoadModel(self, request, context):
163163
modelFile = request.Model
164164

165165
self.cfg_scale = 7
166+
self.PipelineType = request.PipelineType
167+
166168
if request.CFGScale != 0:
167169
self.cfg_scale = request.CFGScale
168170

@@ -244,6 +246,30 @@ def LoadModel(self, request, context):
244246
torch_dtype=torchType,
245247
use_safetensors=True,
246248
variant=variant)
249+
elif request.PipelineType == "FluxPipeline":
250+
self.pipe = FluxPipeline.from_pretrained(
251+
request.Model,
252+
torch_dtype=torch.bfloat16)
253+
if request.LowVRAM:
254+
self.pipe.enable_model_cpu_offload()
255+
elif request.PipelineType == "FluxTransformer2DModel":
256+
dtype = torch.bfloat16
257+
# specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
258+
bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
259+
260+
transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
261+
quantize(transformer, weights=qfloat8)
262+
freeze(transformer)
263+
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
264+
quantize(text_encoder_2, weights=qfloat8)
265+
freeze(text_encoder_2)
266+
267+
self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
268+
self.pipe.transformer = transformer
269+
self.pipe.text_encoder_2 = text_encoder_2
270+
271+
if request.LowVRAM:
272+
self.pipe.enable_model_cpu_offload()
247273

248274
if CLIPSKIP and request.CLIPSkip != 0:
249275
self.clip_skip = request.CLIPSkip
@@ -399,6 +425,13 @@ def GenerateImage(self, request, context):
399425
request.seed
400426
)
401427

428+
if self.PipelineType == "FluxPipeline":
429+
kwargs["max_sequence_length"] = 256
430+
431+
if self.PipelineType == "FluxTransformer2DModel":
432+
kwargs["output_type"] = "pil"
433+
kwargs["generator"] = torch.Generator("cpu").manual_seed(0)
434+
402435
if self.img2vid:
403436
# Load the conditioning image
404437
image = load_image(request.src)

backend/python/diffusers/requirements-cpu.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ accelerate
55
compel
66
peft
77
sentencepiece
8-
torch
8+
torch
9+
optimum-quanto

backend/python/diffusers/requirements-cublas11.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ transformers
66
accelerate
77
compel
88
peft
9-
sentencepiece
9+
sentencepiece
10+
optimum-quanto

backend/python/diffusers/requirements-cublas12.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ transformers
55
accelerate
66
compel
77
peft
8-
sentencepiece
8+
sentencepiece
9+
optimum-quanto

backend/python/diffusers/requirements-hipblas.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ accelerate
88
compel
99
peft
1010
sentencepiece
11+
optimum-quanto

backend/python/diffusers/requirements-intel.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ transformers
1010
accelerate
1111
compel
1212
peft
13-
sentencepiece
13+
sentencepiece
14+
optimum-quanto

0 commit comments

Comments (0)