Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00.
Split FLUX VAE decoding out into its own node from LatentsToImageInvocation.

parent 6a89176c6a
commit 262b67b9cb
invokeai/app/invocations/flux_vae_decode.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import torch
+from einops import rearrange
+from PIL import Image
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux_vae_decode",
+    title="FLUX VAE Decode",
+    tags=["latents", "image", "vae", "l2i", "flux"],
+    category="latents",
+    version="1.0.0",
+)
+class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+
+    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
+        with vae_info as vae:
+            assert isinstance(vae, AutoEncoder)
+            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
+            img = vae.decode(latents)
+
+        img = img.clamp(-1, 1)
+        img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
+        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
+        return img_pil
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+        vae_info = context.models.load(self.vae.vae)
+        image = self._vae_decode(vae_info=vae_info, latents=latents)
+
+        TorchDevice.empty_cache()
+        image_dto = context.images.save(image=image)
+        return ImageOutput.build(image_dto)
invokeai/app/invocations/latents_to_image.py
@@ -10,8 +10,6 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from einops import rearrange
-from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
@@ -26,9 +24,6 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder
-from invokeai.backend.model_manager.config import BaseModelType
-from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
@@ -39,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="Latents to Image",
     tags=["latents", "image", "vae", "l2i"],
     category="latents",
-    version="1.4.0",
+    version="1.3.0",
 )
 class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Generates an image from latents."""
@@ -58,7 +53,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
-    def _vae_decode_stable_diffusion(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
         with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
@@ -114,38 +112,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
         image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-        return image
-
-    def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
-        with vae_info as vae:
-            assert isinstance(vae, AutoEncoder)
-            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
-            img = vae.decode(latents)
-
-        img = img.clamp(-1, 1)
-        img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
-        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
-        return img_pil
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-
-        vae_info = context.models.load(self.vae.vae)
-        if vae_info.config.base == BaseModelType.Flux:
-            if self.fp32:
-                raise NotImplementedError("FLUX VAE decode does not currently support fp32=True.")
-            if self.tiled:
-                raise NotImplementedError("FLUX VAE decode does not currently support tiled=True.")
-            image = self._vae_decode_flux(vae_info=vae_info, latents=latents)
-        elif vae_info.config.base in [
-            BaseModelType.StableDiffusion1,
-            BaseModelType.StableDiffusion2,
-            BaseModelType.StableDiffusionXL,
-        ]:
-            image = self._vae_decode_stable_diffusion(vae_info=vae_info, latents=latents)
-        else:
-            raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")
 
         TorchDevice.empty_cache()
 
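A note on how the new node is wired in: the `@invocation(...)` decorator registers the class under its string id ("flux_vae_decode"). Below is a toy sketch of that registration pattern; `NODE_REGISTRY` and this simplified decorator are illustrative assumptions, not InvokeAI's actual implementation (which also handles schema generation and validation):

```python
from typing import Callable, Type

# Hypothetical registry for illustration only; InvokeAI's real decorator does
# considerably more (pydantic schema generation, id/version validation, UI metadata).
NODE_REGISTRY: dict[str, Type] = {}

def invocation(node_id: str, title: str, tags: list[str], category: str, version: str) -> Callable[[Type], Type]:
    def wrapper(cls: Type) -> Type:
        # Attach the metadata to the class and index it by its string id,
        # which is how graphs refer to a node type.
        cls.node_id, cls.title, cls.tags, cls.category, cls.version = node_id, title, tags, category, version
        NODE_REGISTRY[node_id] = cls
        return cls
    return wrapper

@invocation("flux_vae_decode", title="FLUX VAE Decode", tags=["latents", "image", "vae", "l2i", "flux"], category="latents", version="1.0.0")
class FluxVaeDecodeInvocation:
    """Generates an image from latents."""

print(NODE_REGISTRY["flux_vae_decode"].title)  # FLUX VAE Decode
```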