Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
feat(nodes): add config to limit size of images in denoising
This serves as a relatively crude way to prevent OOM errors during denoising (and any operations downstream of the denoising step, like the VAE decode in Linear UI graphs).

- Add a `max_image_size` config option - this is the total number of pixels, i.e. the image area
- Add logic to `denoise_latents` to scale the `latents` and `noise` to fit this
- Add logic to `color_correct` to scale the reference and mask to fit the image
This commit is contained in:
parent b152fbf72f
commit 8918501869
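To make the area-based cap from the commit message concrete, here is a small sketch of the scaling arithmetic: when a requested size exceeds the pixel budget, both dimensions are shrunk by the square root of the ratio so the scaled area lands near the budget. The variable names and the 768x768 request below are illustrative, not taken from the commit.

import math

# Hypothetical numbers: the default budget is 512 * 512 = 262144 pixels.
max_image_size = 512 * 512
requested = (768, 768)                    # 589824 pixels, over budget
area = requested[0] * requested[1]

# Shrink both dimensions by sqrt(budget / area) so the scaled area ~= budget.
scale = math.sqrt(max_image_size / area)  # ~0.667
scaled = (round(requested[0] * scale), round(requested[1] * scale))
print(scaled)                             # (512, 512)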
@@ -645,13 +645,19 @@ class ColorCorrectInvocation(BaseInvocation):
     mask_blur_radius: float = InputField(default=8, description="Mask blur radius")

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
+
+        init_image = context.services.images.get_pil_image(self.reference.image_name)
+        # fit reference image to the input image
+        if init_image.size != result.size:
+            init_image = init_image.resize((result.width, result.height), Image.BILINEAR)
+
         pil_init_mask = None
         if self.mask is not None:
             pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L")
+            # fit mask to the input image
+            if pil_init_mask.size != result.size:
+                pil_init_mask = pil_init_mask.resize((result.width, result.height), Image.BILINEAR)

-        init_image = context.services.images.get_pil_image(self.reference.image_name)
-
-        result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
-
         # if init_image is None or init_mask is None:
         #    return result
@@ -77,6 +77,25 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]


+def fit_latents(latents: torch.Tensor, max_latents_size: int, device: torch.device) -> torch.Tensor:
+    if max_latents_size == 0:
+        return latents
+
+    latents_area = latents.shape[2] * latents.shape[3]
+    if latents_area <= max_latents_size:
+        return latents
+
+    scale_factor = np.sqrt(max_latents_size / latents_area)
+    scaled_latents = torch.nn.functional.interpolate(
+        latents.to(device),
+        scale_factor=scale_factor,
+        mode="bilinear",
+        antialias=True,
+    )
+
+    return scaled_latents
+
+
 @invocation_output("scheduler_output")
 class SchedulerOutput(BaseInvocationOutput):
     scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
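A quick sanity check of the fit_latents helper added above (hypothetical shapes, not part of the commit): a (1, 4, 128, 128) latent has an area of 16384; with a budget of 64 * 64 = 4096 the scale factor is sqrt(4096 / 16384) = 0.5, so each spatial dimension is halved.

import torch

latents = torch.randn(1, 4, 128, 128)  # (batch, channels, height, width)
scaled = fit_latents(latents=latents, max_latents_size=64 * 64, device=torch.device("cpu"))
print(scaled.shape)                     # torch.Size([1, 4, 64, 64])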
@@ -500,14 +519,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
             seed = None
             noise = None
+            max_image_size = context.services.configuration.max_image_size
+
             if self.noise is not None:
                 noise = context.services.latents.get(self.noise.latents_name)
                 seed = self.noise.seed
+                noise = fit_latents(latents=noise, max_latents_size=max_image_size // 64, device=choose_torch_device())

             if self.latents is not None:
                 latents = context.services.latents.get(self.latents.latents_name)
                 if seed is None:
                     seed = self.latents.seed
+                latents = fit_latents(
+                    latents=latents, max_latents_size=max_image_size // 64, device=choose_torch_device()
+                )

             if noise is not None and noise.shape[1:] != latents.shape[1:]:
                 raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
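A note on the max_image_size // 64 used above: Stable Diffusion latents are 8x smaller than the decoded image along each spatial dimension, so a budget expressed in image pixels becomes 8 * 8 = 64 times smaller when counted in latent elements. With the default budget:

# 512 * 512 image pixels correspond to a 64 x 64 latent, i.e. 4096 latent elements.
assert (512 * 512) // 64 == 64 * 64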
@@ -532,8 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
+                    **lora.dict(exclude={"weight"}), context=context
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
@@ -255,6 +255,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
     force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
+    max_image_size : int = Field(default=512 * 512, description="The maximum size of images, in pixels. The maximum size for latents is inferred from this evaluating `max_image_size // 8`. If the size is exceeded during denoising, the latents will be resized.", category="Generation", )

     # QUEUE
     max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
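Since max_image_size is an ordinary pydantic field on InvokeAIAppConfig, it can be overridden like any other setting; a hypothetical example (the import path and the keyword override are assumptions based on the class shown above, not something this commit demonstrates):

from invokeai.app.services.config import InvokeAIAppConfig

# Hypothetical: raise the budget to ~768x768 worth of pixels on a machine with more VRAM.
config = InvokeAIAppConfig(max_image_size=768 * 768)
print(config.max_image_size)  # 589824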