Split VAE decoding out from the FLUXTextToImageInvocation.

This commit is contained in:
Ryan Dick 2024-08-29 14:50:58 +00:00
parent b33cba500c
commit 4384858be2
2 changed files with 8 additions and 27 deletions

View File

@@ -1,8 +1,6 @@
from typing import Optional from typing import Optional
import torch import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@@ -15,11 +13,10 @@ from invokeai.app.invocations.fields import (
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.model import TransformerField, VAEField from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@@ -65,7 +62,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField( num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50." default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
) )
guidance: float = InputField( guidance: float = InputField(
default=4.0, default=4.0,
@@ -74,11 +71,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
seed: int = InputField(default=0, description="Randomness seed for reproducibility.") seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context) latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents) latents = latents.detach().to("cpu")
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _run_diffusion( def _run_diffusion(
self, self,
@@ -185,20 +183,3 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width) x = unpack(x.float(), self.height, self.width)
return x return x
def _run_vae_decoding(
    self,
    context: InvocationContext,
    latents: torch.Tensor,
) -> Image.Image:
    """Decode FLUX latents into a PIL image with the invocation's VAE.

    This is the method removed by this commit; its logic moved into
    LatentsToImageInvocation._vae_decode_flux.

    Args:
        context: Invocation context used to load the VAE model (self.vae.vae).
        latents: Latent tensor to decode.
            # assumes shape is the unpacked (1, c, h, w) latent from
            # _run_diffusion — TODO confirm against caller
    Returns:
        The decoded image as a PIL Image.
    """
    vae_info = context.models.load(self.vae.vae)
    with vae_info as vae:
        assert isinstance(vae, AutoEncoder)
        # Cast latents to the VAE's working dtype before decoding.
        latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
        img = vae.decode(latents)
    # NOTE(review): original indentation was lost in this diff rendering;
    # the post-processing below presumably sits outside the `with` block,
    # matching the relocated _vae_decode_flux — confirm against history.
    img = img.clamp(-1, 1)
    # Reorder channels-first (c, h, w) to channels-last (h, w, c) for PIL.
    img = rearrange(img[0], "c h w -> h w c")
    # Map the [-1, 1] float range to [0, 255] uint8 on CPU for PIL.
    img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
    return img_pil

View File

@@ -119,7 +119,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, AutoEncoder) assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype()) latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents) img = vae.decode(latents)
img = img.clamp(-1, 1) img = img.clamp(-1, 1)