diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py new file mode 100644 index 0000000000..f69b421a91 --- /dev/null +++ b/invokeai/app/invocations/flux_vae_encode.py @@ -0,0 +1,67 @@ +import einops +import torch + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "flux_vae_encode", + title="FLUX VAE Encode", + tags=["latents", "image", "vae", "i2l", "flux"], + category="latents", + version="1.0.0", +) +class FluxVaeEncodeInvocation(BaseInvocation): + """Encodes an image into latents.""" + + image: ImageField = InputField( + description="The image to encode.", + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + + @staticmethod + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: + # TODO(ryand): Expose seed parameter at the invocation level. + # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes. + # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function + # should be used for VAE encode sampling. + generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0) + with vae_info as vae: + assert isinstance(vae, AutoEncoder) + image_tensor = image_tensor.to( + device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype() + ) + latents = vae.encode(image_tensor, sample=True, generator=generator) + return latents + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) + + vae_info = context.models.load(self.vae.vae) + + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") + + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + + latents = latents.to("cpu") + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 99a92d98de..dadd8bb3a1 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -23,12 +23,9 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.model_manager import LoadedModel -from invokeai.backend.model_manager.config import BaseModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params -from invokeai.backend.util.devices import TorchDevice @invocation( @@ -36,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.2.0", + version="1.1.0", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -55,22 +52,7 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) @staticmethod - def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: - # TODO(ryand): Expose seed parameter at the invocation level. - # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes. - # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function - # should be used for VAE encode sampling. - generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0) - with vae_info as vae: - assert isinstance(vae, AutoEncoder) - image_tensor = image_tensor.to( - device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype() - ) - latents = vae.encode(image_tensor, sample=True, generator=generator) - return latents - - @staticmethod - def vae_encode_stable_diffusion( + def vae_encode( vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 ) -> torch.Tensor: with vae_info as vae: @@ -125,27 +107,6 @@ class ImageToLatentsInvocation(BaseInvocation): return latents - @staticmethod - def vae_encode( - vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 - ) -> torch.Tensor: - if vae_info.config.base == BaseModelType.Flux: - if upcast: - raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.") - if tiled: - raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.") - return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor) - elif vae_info.config.base in [ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - ]: - return ImageToLatentsInvocation.vae_encode_stable_diffusion( - vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size - ) - else: - raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'") - @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name)