feat: Remove TextToLatents / Rename Latents To Latents -> DenoiseLatents

This commit is contained in:
blessedcoolant 2023-08-11 22:20:37 +12:00
parent 231e665675
commit 7c0023ad9e

View File

@ -5,6 +5,7 @@ from typing import List, Literal, Optional, Union
import einops
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@ -14,18 +15,14 @@ from diffusers.models.attention_processor import (
)
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ModelPatcher, BaseModelType
from ...backend.model_management import BaseModelType, ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@ -35,11 +32,13 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
DEFAULT_PRECISION = choose_precision(choose_torch_device())
@ -106,26 +105,31 @@ def get_scheduler(
return scheduler
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
type: Literal["t2l"] = "t2l"
type: Literal["denoise_latents"] = "denoise_latents"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
)
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
mask: Optional[ImageField] = Field(
None,
description="Mask",
)
@validator("cfg_scale")
def ge_one(cls, v):
@ -143,12 +147,11 @@ class TextToLatentsInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Text To Latents",
"tags": ["latents"],
"title": "Denoise Latents",
"tags": ["denoise", "latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
@ -331,121 +334,6 @@ class TextToLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
seed = self.noise.seed
if seed is None:
seed = 0
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
num_inference_steps, timesteps = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=0.0,
denoising_end=self.denoising_end,
)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
seed=seed,
timesteps=timesteps,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
# denoising_start = 1 - strength
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
mask: Optional[ImageField] = Field(
None, description="Mask",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Latent To Latents",
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
},
},
}
def prep_mask_tensor(self, mask, context, lantents):
if mask is None:
return None
@ -457,9 +345,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
mask_tensor = tv_resize(
mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR
)
mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
return 1 - mask_tensor
@torch.no_grad()