Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation.
parent 020e8eb413
commit 21d7ca45e6
@@ -9,6 +9,7 @@ from diffusers.models.attention_processor import (
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION
@@ -23,6 +24,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion import set_seamless
 from invokeai.backend.util.devices import TorchDevice
 
@@ -48,16 +50,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-        vae_info = context.models.load(self.vae.vae)
+    @staticmethod
+    def vae_decode(
+        context: InvocationContext,
+        vae_info: LoadedModel,
+        seamless_axes: list[str],
+        latents: torch.Tensor,
+        use_fp32: bool,
+        use_tiling: bool,
+    ) -> Image.Image:
         assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
             assert isinstance(vae, torch.nn.Module)
             latents = latents.to(vae.device)
-            if self.fp32:
+            if use_fp32:
                 vae.to(dtype=torch.float32)
 
                 use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@@ -82,7 +88,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
-            if self.tiled or context.config.get().force_tiled_decode:
+            if use_tiling or context.config.get().force_tiled_decode:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -102,6 +108,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
             TorchDevice.empty_cache()
 
+        return image
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+        vae_info = context.models.load(self.vae.vae)
+
+        image = self.vae_decode(
+            context=context,
+            vae_info=vae_info,
+            seamless_axes=self.vae.seamless_axes,
+            latents=latents,
+            use_fp32=self.fp32,
+            use_tiling=self.tiled,
+        )
         image_dto = context.images.save(image=image)
 
         return ImageOutput.build(image_dto)
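
Because vae_decode is now a @staticmethod, it can be called on the class itself with no LatentsToImageInvocation instance in hand. A minimal sketch of what a caller might look like after this change (the wrapper function and the import path are illustrative assumptions; the argument names and types follow the signature visible in the diff):

import torch

# Module path assumed from the file being patched.
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation

def decode_latents(context, vae_info, latents: torch.Tensor):
    # No instance is constructed: vae_decode is a staticmethod on the class.
    image = LatentsToImageInvocation.vae_decode(
        context=context,       # an InvocationContext
        vae_info=vae_info,     # a LoadedModel wrapping the VAE
        seamless_axes=[],      # no seamless axes in this example
        latents=latents,
        use_fp32=False,        # keep the VAE in its loaded precision
        use_tiling=False,      # decode in a single pass
    )
    return image  # a PIL Image.Image, e.g. for context.images.save(...)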