feat(nodes): add config to limit size of images in denoising

This serves as a relatively crude way to prevent OOM errors during denoising (and any operations downstream of the denoising step, like the VAE decode in Linear UI graphs).

- Add `max_image_size` config options - this is the total number of pixels eg the area
- Add logic to `denoise_latents` to scale the `latents` and `noise` to fit this
- Add logic to `color_correct` to scale the reference and mask to fit the image
This commit is contained in:
psychedelicious 2023-09-22 16:37:42 +10:00
parent b152fbf72f
commit 8918501869
3 changed files with 37 additions and 6 deletions

View File

@ -645,13 +645,19 @@ class ColorCorrectInvocation(BaseInvocation):
mask_blur_radius: float = InputField(default=8, description="Mask blur radius") mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
init_image = context.services.images.get_pil_image(self.reference.image_name)
# fit reference image to the input image
if init_image.size != result.size:
init_image = init_image.resize((result.width, result.height), Image.BILINEAR)
pil_init_mask = None pil_init_mask = None
if self.mask is not None: if self.mask is not None:
pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L")
# fit mask to the input image
init_image = context.services.images.get_pil_image(self.reference.image_name) if pil_init_mask.size != result.size:
pil_init_mask = pil_init_mask.resize((result.width, result.height), Image.BILINEAR)
result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
# if init_image is None or init_mask is None: # if init_image is None or init_mask is None:
# return result # return result

View File

@ -77,6 +77,25 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
def fit_latents(latents: torch.Tensor, max_latents_size: int, device: torch.device) -> torch.Tensor:
if max_latents_size == 0:
return latents
latents_area = latents.shape[2] * latents.shape[3]
if latents_area <= max_latents_size:
return latents
scale_factor = np.sqrt(max_latents_size / latents_area)
scaled_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=scale_factor,
mode="bilinear",
antialias=True,
)
return scaled_latents
@invocation_output("scheduler_output") @invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput): class SchedulerOutput(BaseInvocationOutput):
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@ -500,14 +519,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
with SilenceWarnings(): # this quenches NSFW nag from diffusers with SilenceWarnings(): # this quenches NSFW nag from diffusers
seed = None seed = None
noise = None noise = None
max_image_size = context.services.configuration.max_image_size
if self.noise is not None: if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
seed = self.noise.seed seed = self.noise.seed
noise = fit_latents(latents=noise, max_latents_size=max_image_size // 64, device=choose_torch_device())
if self.latents is not None: if self.latents is not None:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
if seed is None: if seed is None:
seed = self.latents.seed seed = self.latents.seed
latents = fit_latents(
latents=latents, max_latents_size=max_image_size // 64, device=choose_torch_device()
)
if noise is not None and noise.shape[1:] != latents.shape[1:]: if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}") raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
@ -532,8 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), **lora.dict(exclude={"weight"}), context=context
context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info

View File

@ -255,6 +255,7 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
max_image_size : int = Field(default=512 * 512, description="The maximum size of images, in pixels. The maximum size for latents is inferred from this evaluating `max_image_size // 8`. If the size is exceeded during denoising, the latents will be resized.", category="Generation", )
# QUEUE # QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", ) max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )