From 4384858be270fa82d29141a50b8934e416f7b86c Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 29 Aug 2024 14:50:58 +0000
Subject: [PATCH] Split VAE decoding out from the FLUXTextToImageInvocation.

---
 .../app/invocations/flux_text_to_image.py    | 33 ++++----------------
 invokeai/app/invocations/latents_to_image.py |  2 +-
 2 files changed, 8 insertions(+), 27 deletions(-)

diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index dacc282543..d123886488 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,8 +1,6 @@
 from typing import Optional
 
 import torch
-from einops import rearrange
-from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
@@ -15,11 +13,10 @@ from invokeai.app.invocations.fields import (
     WithMetadata,
 )
 from invokeai.app.invocations.model import TransformerField, VAEField
-from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -65,7 +62,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     num_steps: int = InputField(
-        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
+        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
     )
     guidance: float = InputField(
         default=4.0,
@@ -74,11 +71,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = self._run_diffusion(context)
-        image = self._run_vae_decoding(context, latents)
-        image_dto = context.images.save(image=image)
-        return ImageOutput.build(image_dto)
+        latents = latents.detach().to("cpu")
+
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
 
     def _run_diffusion(
         self,
@@ -185,20 +183,3 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         x = unpack(x.float(), self.height, self.width)
 
         return x
-
-    def _run_vae_decoding(
-        self,
-        context: InvocationContext,
-        latents: torch.Tensor,
-    ) -> Image.Image:
-        vae_info = context.models.load(self.vae.vae)
-        with vae_info as vae:
-            assert isinstance(vae, AutoEncoder)
-            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
-            img = vae.decode(latents)
-
-        img = img.clamp(-1, 1)
-        img = rearrange(img[0], "c h w -> h w c")
-        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
-
-        return img_pil
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index 90885d5c40..660e7e84d5 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -119,7 +119,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
+            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
             img = vae.decode(latents)
 
         img = img.clamp(-1, 1)
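Reviewer note: the decode math removed from FluxTextToImageInvocation is not lost; it already lives in LatentsToImageInvocation._vae_decode_flux (second file above), which this patch also fixes to move latents onto the execution device before decoding. For reference, here is a minimal standalone sketch of that decode path, assembled from the code in this diff. The decode_flux_latents wrapper is hypothetical and not part of the codebase; the vae.decode call, the device/dtype handling, and the post-processing are taken from the patch itself.

import torch
from einops import rearrange
from PIL import Image

from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.util.devices import TorchDevice


@torch.no_grad()
def decode_flux_latents(vae: AutoEncoder, latents: torch.Tensor) -> Image.Image:
    # The latents now arrive from tensor storage on the CPU, because invoke()
    # detaches them and moves them to "cpu" before saving. They therefore must
    # be moved back to the execution device, which is why the dtype-only .to()
    # call in _vae_decode_flux gains an explicit device argument in this patch.
    latents = latents.to(
        device=TorchDevice.choose_torch_device(),
        dtype=TorchDevice.choose_torch_dtype(),
    )
    img = vae.decode(latents)

    # Same post-processing as the removed _run_vae_decoding: clamp to the
    # [-1, 1] model range, reorder CHW to HWC, and map to 8-bit RGB.
    img = img.clamp(-1, 1)
    img = rearrange(img[0], "c h w -> h w c")
    return Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

In graph terms, the FLUX Text to Image node now emits latents only, so its output should be wired into a Latents to Image node carrying the FLUX VAE rather than producing an image directly.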