mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add LatentsToImage node (VAE encode)
This commit is contained in:
parent
ff3aa57117
commit
6102e560ba
@ -1,7 +1,8 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, Union
|
||||||
|
import einops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -13,7 +14,8 @@ from ...backend.model_management.model_manager import ModelManager
|
|||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
@ -433,3 +435,47 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.set(name, resized_latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||||
|
|
||||||
|
|
||||||
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
|
type: Literal["i2l"] = "i2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||||
|
model: str = Field(default="", description="The model to use")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents", "image"],
|
||||||
|
"type_hints": {"model": "model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: this only really needs the vae
|
||||||
|
model_info = choose_model(context.services.model_manager, self.model)
|
||||||
|
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||||
|
|
||||||
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
|
||||||
|
if image_tensor.dim() == 3:
|
||||||
|
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||||
|
|
||||||
|
latents = model.non_noised_latents_from_image(
|
||||||
|
image_tensor,
|
||||||
|
device=model._model_group.device_for(model.unet),
|
||||||
|
dtype=model.unet.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
context.services.latents.set(name, latents)
|
||||||
|
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user