mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Expose LatentsToImageInvocation.vae_decode(...) as a staticmethod.
This commit is contained in:
parent
f9c61f1b6c
commit
5461de41df
@ -10,6 +10,7 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
@ -24,6 +25,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@ -34,7 +36,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.3.0",
|
||||
version="1.3.1",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@ -53,16 +55,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
@staticmethod
|
||||
def vae_decode(
|
||||
context: InvocationContext,
|
||||
vae_info: LoadedModel,
|
||||
seamless_axes: list[str],
|
||||
latents: torch.Tensor,
|
||||
use_fp32: bool,
|
||||
use_tiling: bool,
|
||||
tile_size: int,
|
||||
) -> Image.Image:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
if use_fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
@ -87,17 +94,17 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
if use_tiling or context.config.get().force_tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if self.tile_size > 0:
|
||||
if tile_size > 0:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=self.tile_size,
|
||||
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
|
||||
tile_sample_min_size=tile_size,
|
||||
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
|
||||
@ -116,6 +123,24 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image = self.vae_decode(
|
||||
context,
|
||||
vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=latents,
|
||||
use_fp32=self.fp32,
|
||||
use_tiling=self.tiled,
|
||||
tile_size=self.tile_size,
|
||||
)
|
||||
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
Loading…
Reference in New Issue
Block a user