Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Split FLUX VAE encoding out into its own node from ImageToLatentsInvocation.
commit 6a89176c6a (parent 7d854f32b0)
invokeai/app/invocations/flux_vae_encode.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import einops
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_vae_encode",
    title="FLUX VAE Encode",
    tags=["latents", "image", "vae", "i2l", "flux"],
    category="latents",
    version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
    """Encodes an image into latents."""

    image: ImageField = InputField(
        description="The image to encode.",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Expose seed parameter at the invocation level.
        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
        # should be used for VAE encode sampling.
        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            image_tensor = image_tensor.to(
                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
            )
            latents = vae.encode(image_tensor, sample=True, generator=generator)
        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        vae_info = context.models.load(self.vae.vae)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
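The fixed seed of 0 in vae_encode (and the TODOs about exposing it) matter because sample=True makes encoding stochastic. The following is a minimal, self-contained sketch, not InvokeAI code, of the reparameterized sampling this implies; it assumes the usual diagonal-Gaussian VAE posterior, and sample_diagonal_gaussian plus the toy shapes are illustrative only:

import torch

def sample_diagonal_gaussian(mean: torch.Tensor, logvar: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Reparameterization: z = mean + std * eps, with eps drawn from the seeded generator.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype)
    return mean + std * eps

# Toy posterior parameters; in the real node these come from the encoder.
mean = torch.randn(1, 16, 8, 8)
logvar = torch.zeros(1, 16, 8, 8)

z1 = sample_diagonal_gaussian(mean, logvar, torch.Generator().manual_seed(0))
z2 = sample_diagonal_gaussian(mean, logvar, torch.Generator().manual_seed(0))
assert torch.equal(z1, z2)  # same seed -> identical latents, like the node's manual_seed(0)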
Changes to ImageToLatentsInvocation:

@@ -23,12 +23,9 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import LoadedModel
-from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
-from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -36,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="Image to Latents",
     tags=["latents", "image", "vae", "i2l"],
     category="latents",
-    version="1.2.0",
+    version="1.1.0",
 )
 class ImageToLatentsInvocation(BaseInvocation):
     """Encodes an image into latents."""
@@ -55,22 +52,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @staticmethod
-    def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        # TODO(ryand): Expose seed parameter at the invocation level.
-        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
-        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
-        # should be used for VAE encode sampling.
-        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
-        with vae_info as vae:
-            assert isinstance(vae, AutoEncoder)
-            image_tensor = image_tensor.to(
-                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
-            )
-            latents = vae.encode(image_tensor, sample=True, generator=generator)
-        return latents
-
-    @staticmethod
-    def vae_encode_stable_diffusion(
+    def vae_encode(
         vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
     ) -> torch.Tensor:
         with vae_info as vae:
@@ -125,27 +107,6 @@ class ImageToLatentsInvocation(BaseInvocation):
 
         return latents
 
-    @staticmethod
-    def vae_encode(
-        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
-    ) -> torch.Tensor:
-        if vae_info.config.base == BaseModelType.Flux:
-            if upcast:
-                raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.")
-            if tiled:
-                raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.")
-            return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor)
-        elif vae_info.config.base in [
-            BaseModelType.StableDiffusion1,
-            BaseModelType.StableDiffusion2,
-            BaseModelType.StableDiffusionXL,
-        ]:
-            return ImageToLatentsInvocation.vae_encode_stable_diffusion(
-                vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size
-            )
-        else:
-            raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
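With the dispatcher above deleted, picking an encode node by model base type now happens wherever a graph is assembled. Below is a hedged sketch of that selection, echoing the removed branch; the type string "flux_vae_encode" comes from this commit, while "i2l" and the base-type strings ("flux", "sd-1", "sd-2", "sdxl") are assumptions about the surrounding codebase:

def encode_node_type(base: str) -> str:
    # Hypothetical helper: map a model base to the node that can VAE-encode for it,
    # mirroring the base-type dispatch this commit deletes from ImageToLatentsInvocation.
    if base == "flux":
        return "flux_vae_encode"  # the new node introduced by this commit
    if base in {"sd-1", "sd-2", "sdxl"}:
        return "i2l"  # assumed type string for ImageToLatentsInvocation
    raise ValueError(f"Unsupported VAE base type: '{base}'")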