Commit 58eb317

Memory check before inference to avoid VAE Decode using exceeded VRAM.
Before the actual decode, check that free memory is at least the predicted amount. If it is not, try to free the required amount of memory, and if that still fails, switch directly to tiled VAE decoding. PyTorch seems to keep occupying memory after an OOM occurs until the model is destroyed, so this commit tries to prevent the OOM from happening in the first place for VAE Decode. This addresses VAE Decode running with exceeded VRAM, as reported in #5737.
1 parent 839ed33 · commit 58eb317
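In outline, the new guard works as follows. This is a minimal sketch assembled from the diff below, not the verbatim source; the surrounding `try`/`except` and the batched decode loop are omitted:

```python
# Sketch of the pre-decode memory check (names follow the diff below).
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)

free_memory = model_management.get_free_memory(self.device)
if free_memory < memory_used:
    # Not enough VRAM: try to evict other loaded models, keeping this VAE.
    model_management.free_memory(memory_used, self.device, [self.patcher])
    free_memory = model_management.get_free_memory(self.device)
    if free_memory < memory_used:
        # Still not enough: predict the OOM and jump straight to the
        # tiled-decode fallback instead of letting PyTorch actually OOM.
        predicted_oom = True
        raise model_management.OOM_EXCEPTION
```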

File tree

2 files changed: +19 -2 lines changed

comfy/model_management.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -425,7 +425,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
-            if shift_model not in keep_loaded:
+            if shift_model not in keep_loaded and shift_model.model not in keep_loaded:
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False
```

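The extra `shift_model.model not in keep_loaded` clause matters because `current_loaded_models` holds wrapper objects, while the caller in `VAE.decode` below passes the raw patcher (`[self.patcher]`) in `keep_loaded`. A tiny self-contained illustration, using stand-in classes rather than the real ComfyUI ones, of why the wrapper-only check would not protect the in-use model:

```python
# Stand-in objects to illustrate the membership check (not ComfyUI classes).
class LoadedModel:
    def __init__(self, model):
        self.model = model  # the raw model/patcher being wrapped

patcher = object()            # stands in for a ModelPatcher
wrapper = LoadedModel(patcher)
keep_loaded = [patcher]       # callers pass the raw patcher, not the wrapper

# Old check: the wrapper is never in keep_loaded, so it stays unloadable.
print(wrapper not in keep_loaded)                                       # True
# New check: also comparing the wrapped model keeps it protected.
print(wrapper not in keep_loaded and wrapper.model not in keep_loaded)  # False
```
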
comfy/sd.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -348,11 +348,24 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
 
     def decode(self, samples_in):
+        predicted_oom = False
+        samples = None
+        out = None
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
             free_memory = model_management.get_free_memory(self.device)
+            logging.debug(f"Free memory: {free_memory} bytes, predicted memory usage of one batch: {memory_used} bytes")
+            if free_memory < memory_used:
+                logging.debug("Possible out of memory detected, trying to free memory.")
+                model_management.free_memory(memory_used, self.device, [self.patcher])
+                free_memory = model_management.get_free_memory(self.device)
+                logging.debug(f"Free memory: {free_memory} bytes")
+                if free_memory < memory_used:
+                    logging.warning("Warning: Out of memory is predicted for regular VAE decoding, switching directly to tiled VAE decoding.")
+                    predicted_oom = True
+                    raise model_management.OOM_EXCEPTION
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
@@ -363,7 +376,11 @@ def decode(self, samples_in):
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION as e:
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            samples = None
+            out = None
+            pixel_samples = None
+            if not predicted_oom:
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
             if dims == 1:
                 pixel_samples = self.decode_tiled_1d(samples_in)
```

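For reference, the batch sizing in the regular path simply divides the measured free memory by the per-batch estimate, never going below one. A quick worked example under assumed numbers:

```python
# Worked example with assumed numbers: 8 GiB free, 3 GiB predicted per batch.
free_memory = 8 * 1024**3
memory_used = 3 * 1024**3
batch_number = max(1, int(free_memory / memory_used))
print(batch_number)  # 2 -> decode two latents at a time
```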