|
18 | 18 | import grpc
19 | 19 |
20 | 20 | from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
21 |    | -    EulerAncestralDiscreteScheduler
   | 21 | +    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
22 | 22 | from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
23 | 23 | from diffusers.pipelines.stable_diffusion import safety_checker
24 | 24 | from diffusers.utils import load_image, export_to_video
25 | 25 | from compel import Compel, ReturnedEmbeddingsType
26 |    | -
27 |    | -from transformers import CLIPTextModel
   | 26 | +from optimum.quanto import freeze, qfloat8, quantize
   | 27 | +from transformers import CLIPTextModel, T5EncoderModel
28 | 28 | from safetensors.torch import load_file
29 | 29 |
30 | 30 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24
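
For context on the new optimum.quanto imports: quantize() replaces a module's weight tensors with quantized equivalents (float8 here, via qfloat8), and freeze() materializes those quantized weights in place so the full-precision originals can be released. A toy sketch on a stand-in module, assuming the optimum-quanto package is installed (not LocalAI code):

    import torch
    from optimum.quanto import freeze, qfloat8, quantize

    # Stand-in for the Flux transformer: quantize() marks the Linear weights
    # for float8 storage, freeze() converts them in place.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    quantize(model, weights=qfloat8)
    freeze(model)
    print(model)  # Linear layers are now their quantized counterparts
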
@@ -163,6 +163,8 @@ def LoadModel(self, request, context):
163 | 163 |         modelFile = request.Model
164 | 164 |
165 | 165 |         self.cfg_scale = 7
    | 166 | +        self.PipelineType = request.PipelineType
    | 167 | +
166 | 168 |         if request.CFGScale != 0:
167 | 169 |             self.cfg_scale = request.CFGScale
168 | 170 |
@@ -244,6 +246,30 @@ def LoadModel(self, request, context):
244 | 246 |                 torch_dtype=torchType,
245 | 247 |                 use_safetensors=True,
246 | 248 |                 variant=variant)
    | 249 | +        elif request.PipelineType == "FluxPipeline":
    | 250 | +            self.pipe = FluxPipeline.from_pretrained(
    | 251 | +                request.Model,
    | 252 | +                torch_dtype=torch.bfloat16)
    | 253 | +            if request.LowVRAM:
    | 254 | +                self.pipe.enable_model_cpu_offload()
    | 255 | +        elif request.PipelineType == "FluxTransformer2DModel":
    | 256 | +            dtype = torch.bfloat16
    | 257 | +            # specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
    | 258 | +            bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
    | 259 | +
    | 260 | +            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
    | 261 | +            quantize(transformer, weights=qfloat8)
    | 262 | +            freeze(transformer)
    | 263 | +            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
    | 264 | +            quantize(text_encoder_2, weights=qfloat8)
    | 265 | +            freeze(text_encoder_2)
    | 266 | +
    | 267 | +            self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
    | 268 | +            self.pipe.transformer = transformer
    | 269 | +            self.pipe.text_encoder_2 = text_encoder_2
    | 270 | +
    | 271 | +            if request.LowVRAM:
    | 272 | +                self.pipe.enable_model_cpu_offload()
247 | 273 |
248 | 274 |         if CLIPSKIP and request.CLIPSkip != 0:
249 | 275 |             self.clip_skip = request.CLIPSkip
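
The FluxTransformer2DModel branch above can be exercised outside the gRPC backend. A standalone sketch under the same assumptions, with a hypothetical local checkpoint path standing in for request.Model and the ChuckMcSneed/FLUX.1-dev repo supplying the remaining components:

    import os
    import torch
    from diffusers import FluxPipeline, FluxTransformer2DModel
    from optimum.quanto import freeze, qfloat8, quantize
    from transformers import T5EncoderModel

    dtype = torch.bfloat16
    bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
    model_file = "flux1-dev.safetensors"  # hypothetical local checkpoint path

    # Load the transformer from a single .safetensors file, then store its
    # weights as float8: roughly half the memory of bf16.
    transformer = FluxTransformer2DModel.from_single_file(model_file, torch_dtype=dtype)
    quantize(transformer, weights=qfloat8)
    freeze(transformer)

    # The T5 text encoder is the next-largest component; quantize it the same way.
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # Assemble the pipeline without those two components, then attach the
    # quantized versions and offload idle submodules to CPU.
    pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
    pipe.transformer = transformer
    pipe.text_encoder_2 = text_encoder_2
    pipe.enable_model_cpu_offload()
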
@@ -399,6 +425,13 @@ def GenerateImage(self, request, context):
399 | 425 |             request.seed
400 | 426 |         )
401 | 427 |
    | 428 | +        if self.PipelineType == "FluxPipeline":
    | 429 | +            kwargs["max_sequence_length"] = 256
    | 430 | +
    | 431 | +        if self.PipelineType == "FluxTransformer2DModel":
    | 432 | +            kwargs["output_type"] = "pil"
    | 433 | +            kwargs["generator"] = torch.Generator("cpu").manual_seed(0)
    | 434 | +
402 | 435 |         if self.img2vid:
403 | 436 |             # Load the conditioning image
404 | 437 |             image = load_image(request.src)
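Continuing the standalone sketch, the kwargs this GenerateImage hunk adds correspond to a pipeline call along these lines (prompt, resolution, and step count are illustrative only):

    # Fixed CPU generator, as in the hunk above, which pins the seed to 0.
    image = pipe(
        "a photo of an astronaut riding a horse on the moon",
        width=1024,
        height=1024,
        num_inference_steps=28,
        max_sequence_length=256,
        output_type="pil",
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux-test.png")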