SDXL Prompt and t2l nodes draft, add fp32 to vae decode

This commit is contained in:
Sergey Borisov
2023-07-11 18:19:36 +03:00
parent 34cff848c7
commit 358ced6bab
3 changed files with 537 additions and 2 deletions

View File

@ -28,6 +28,13 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@ -449,6 +456,7 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
@ -467,6 +475,31 @@ class LatentsToImageInvocation(BaseInvocation):
)
with vae_info as vae:
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else: