From 6102e560ba7f17e91f6efbd1711d5c3d353a763d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 5 May 2023 15:15:55 +1000
Subject: [PATCH] feat(nodes): add ImageToLatents node (VAE encode)

---
 invokeai/app/invocations/latent.py | 50 ++++++++++++++++++++++++++++--
 1 file changed, 48 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 0d3ef4a8cd..01c7623e03 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,7 +1,8 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 import random
-from typing import Literal, Optional
+from typing import Literal, Optional, Union
 
+import einops
 from pydantic import BaseModel, Field
 import torch
@@ -13,7 +14,8 @@ from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
+from ...backend.prompting.conditioning import get_uc_and_c_and_ec
+from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
@@ -433,3 +435,47 @@ class ScaleLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
+
+
+class ImageToLatentsInvocation(BaseInvocation):
+    """Encodes an image into latents."""
+
+    type: Literal["i2l"] = "i2l"
+
+    # Inputs
+    image: Union[ImageField, None] = Field(description="The image to encode")
+    model: str = Field(default="", description="The model to use")
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "image"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+
+        # TODO: this only really needs the vae
+        model_info = choose_model(context.services.model_manager, self.model)
+        model: StableDiffusionGeneratorPipeline = model_info["model"]
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = model.non_noised_latents_from_image(
+            image_tensor,
+            device=model._model_group.device_for(model.unet),
+            dtype=model.unet.dtype,
+        )
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
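
The heavy lifting above is delegated to StableDiffusionGeneratorPipeline.non_noised_latents_from_image. For readers unfamiliar with that helper, a minimal sketch of the equivalent VAE encode against a bare diffusers AutoencoderKL is below; the function name is hypothetical, the input is assumed to already be normalized to [-1, 1], and 0.18215 is the standard SD 1.x latent scaling factor rather than a value read from this patch:

    import torch
    from diffusers import AutoencoderKL

    @torch.no_grad()
    def encode_to_latents(vae: AutoencoderKL, image_tensor: torch.Tensor) -> torch.Tensor:
        """VAE-encode a (B, 3, H, W) tensor in [-1, 1] into (B, 4, H/8, W/8) latents."""
        encoder_output = vae.encode(image_tensor.to(vae.device, vae.dtype))
        latents = encoder_output.latent_dist.sample()  # sample from the encoder's Gaussian posterior
        return 0.18215 * latents  # assumed SD 1.x scaling factor, not taken from this patch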
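
The einops.rearrange guard only adds a batch dimension when image_resized_to_grid_as_tensor returns an unbatched (C, H, W) tensor; it is equivalent to unsqueeze(0):

    import torch
    import einops

    x = torch.rand(3, 512, 512)  # a single RGB image with no batch dimension
    y = einops.rearrange(x, "c h w -> 1 c h w")
    assert y.shape == (1, 3, 512, 512)
    assert torch.equal(y, x.unsqueeze(0))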
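
Downstream nodes never receive the tensor itself, only the LatentsField carrying its name; consumers are expected to look the tensor back up from the latents service, mirroring how ScaleLatentsInvocation reads its input. A hedged sketch of the consuming side (the class, its type string, and the pass-through behaviour are illustrative, not part of this patch):

    class ExamplePassthroughInvocation(BaseInvocation):
        """Hypothetical node that reads latents written by ImageToLatentsInvocation."""

        type: Literal["example_passthrough"] = "example_passthrough"
        latents: Optional[LatentsField] = Field(description="The latents to read", default=None)

        def invoke(self, context: InvocationContext) -> LatentsOutput:
            # Resolve the tensor by the name the producing node stored it under.
            latents = context.services.latents.get(self.latents.latents_name)

            # ...a real node would transform `latents` here...

            # Re-store under this node's own name so later nodes can find it.
            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.set(name, latents)
            return LatentsOutput(latents=LatentsField(latents_name=name))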