BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement new model loader and modify invocations and embeddings

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move the seamless and silence_warnings utils into a better location.
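
For node authors, the breaking change condenses to one pattern: replace the
(base_model, model_type, model_name) triple with the model record key. A
before/after sketch condensed from the hunks below (ControlNet shown; the same
change applies to VAEs, schedulers, LoRAs, IP adapters, CLIP vision encoders,
T2I adapters, and UNets):

    # Before: models addressed by base/type/name
    control_model = context.models.load(
        model_name=control_info.control_model.model_name,
        model_type=ModelType.ControlNet,
        base_model=control_info.control_model.base_model,
    )

    # After: models addressed by their unique record key
    control_model = context.services.model_records.load_model(
        key=control_info.control_model.key,
        context=context,
    )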
Author: Lincoln Stein
Date: 2024-02-05 22:56:32 -05:00
Committed by: psychedelicious
Parent: fbded1c0f2
Commit: dfcf38be91
31 changed files with 727 additions and 496 deletions


@@ -3,13 +3,13 @@
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import List, Literal, Optional, Union
+from typing import Iterator, List, Literal, Optional, Tuple, Union

 import einops
 import numpy as np
 import torch
 import torchvision.transforms as T
-from diffusers import AutoencoderKL, AutoencoderTiny
+from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
 from diffusers.models.attention_processor import (
@@ -46,14 +46,13 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.embeddings.model_patcher import ModelPatcher
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.model_management.models import ModelType, SilenceWarnings
+from invokeai.backend.model_manager import AnyModel, BaseModelType
+from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
+from invokeai.backend.util.silence_warnings import SilenceWarnings
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import BaseModelType
-from ...backend.model_management.seamless import set_seamless
-from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     IPAdapterData,
@@ -149,7 +148,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         )
         if image is not None:
-            vae_info = context.models.load(**self.vae.vae.model_dump())
+            vae_info = context.services.model_records.load_model(
+                **self.vae.vae.model_dump(),
+                context=context,
+            )

             img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -175,7 +177,10 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
+    orig_scheduler_info = context.services.model_records.load_model(
+        **scheduler_info.model_dump(),
+        context=context,
+    )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
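
get_scheduler loads the graph's scheduler model only to harvest its config,
then instantiates whichever scheduler class the user selected via
SCHEDULER_MAP. Stripped of InvokeAI's service plumbing, the underlying
diffusers pattern looks like this (a sketch; the repo id is only an example):

    from diffusers import DDIMScheduler, EulerDiscreteScheduler

    # Load whatever scheduler the pipeline shipped with, keep only its config...
    orig = DDIMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
    )
    # ...then build the scheduler the user actually asked for from that config.
    scheduler = EulerDiscreteScheduler.from_config(orig.config)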
@@ -389,10 +394,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
            controlnet_data = []
            for control_info in control_list:
                control_model = exit_stack.enter_context(
-                    context.models.load(
-                        model_name=control_info.control_model.model_name,
-                        model_type=ModelType.ControlNet,
-                        base_model=control_info.control_model.base_model,
+                    context.services.model_records.load_model(
+                        key=control_info.control_model.key,
+                        context=context,
                    )
                )
@@ -456,17 +460,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
        conditioning_data.ip_adapter_conditioning = []
        for single_ip_adapter in ip_adapter:
            ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.models.load(
-                    model_name=single_ip_adapter.ip_adapter_model.model_name,
-                    model_type=ModelType.IPAdapter,
-                    base_model=single_ip_adapter.ip_adapter_model.base_model,
+                context.services.model_records.load_model(
+                    key=single_ip_adapter.ip_adapter_model.key,
+                    context=context,
                )
            )

-            image_encoder_model_info = context.models.load(
-                model_name=single_ip_adapter.image_encoder_model.model_name,
-                model_type=ModelType.CLIPVision,
-                base_model=single_ip_adapter.image_encoder_model.base_model,
+            image_encoder_model_info = context.services.model_records.load_model(
+                key=single_ip_adapter.image_encoder_model.key,
+                context=context,
            )

            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
@@ -518,10 +520,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
        t2i_adapter_data = []
        for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.models.load(
-                model_name=t2i_adapter_field.t2i_adapter_model.model_name,
-                model_type=ModelType.T2IAdapter,
-                base_model=t2i_adapter_field.t2i_adapter_model.base_model,
+            t2i_adapter_model_info = context.services.model_records.load_model(
+                key=t2i_adapter_field.t2i_adapter_model.key,
+                context=context,
            )
            image = context.images.get_pil(t2i_adapter_field.image.image_name)
@@ -556,7 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                do_classifier_free_guidance=False,
                width=t2i_input_width,
                height=t2i_input_height,
-                num_channels=t2i_adapter_model.config.in_channels,
+                num_channels=t2i_adapter_model.config["in_channels"],  # mypy treats this as a FrozenDict
                device=t2i_adapter_model.device,
                dtype=t2i_adapter_model.dtype,
                resize_mode=t2i_adapter_field.resize_mode,
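
The in_channels change is cosmetic at runtime: a diffusers model's config is a
FrozenDict, a dict subclass that also exposes its keys as attributes, so both
access forms return the same value; only the mapping form is visible to mypy.
A quick illustration (the repo id is just one example of a T2I-Adapter
checkpoint):

    from diffusers.models.adapter import T2IAdapter

    adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v1")
    # Equivalent at runtime; only the subscript form type-checks under mypy.
    assert adapter.config.in_channels == adapter.config["in_channels"]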
@@ -662,22 +663,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
        def step_callback(state: PipelineIntermediateState):
            context.util.sd_step_callback(state, self.unet.unet.base_model)

-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
            for lora in self.unet.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.services.model_records.load_model(
+                    **lora.model_dump(exclude={"weight"}),
+                    context=context,
+                )
+                yield (lora_info.model, lora.weight)
                del lora_info
            return

-        unet_info = context.models.load(**self.unet.unet.model_dump())
+        unet_info = context.services.model_records.load_model(
+            **self.unet.unet.model_dump(),
+            context=context,
+        )
+        assert isinstance(unet_info.model, UNet2DConditionModel)
        with (
            ExitStack() as exit_stack,
-            ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
-            set_seamless(unet_info.context.model, self.unet.seamless_axes),
+            ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
+            set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
            unet_info as unet,
            # Apply the LoRA after unet has been moved to its target device for faster patching.
            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
        ):
+            assert isinstance(unet, UNet2DConditionModel)
            latents = latents.to(device=unet.device, dtype=unet.dtype)
            if noise is not None:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
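
_lora_loader is deliberately a generator: ModelPatcher.apply_lora_unet pulls
(model, weight) pairs one at a time, so each LoRA enters the cache only while
it is being applied, and `del lora_info` drops the handle before the next
load. A simplified sketch of the consuming side, with a plain dict of weight
deltas standing in for a real LoRA model and a hypothetical apply_loras in
place of the real patcher:

    from contextlib import contextmanager
    from typing import Dict, Iterator, Tuple

    import torch

    # Stand-in LoRA "model": a mapping of parameter name -> weight delta.
    LoraDeltas = Dict[str, torch.Tensor]

    @contextmanager
    def apply_loras(unet: torch.nn.Module, loras: Iterator[Tuple[LoraDeltas, float]]):
        # Hypothetical stand-in for ModelPatcher.apply_lora_unet: iterate
        # lazily so only one LoRA needs to be resident while it is applied,
        # saving each original weight once so it can be restored on exit.
        originals: Dict[str, torch.Tensor] = {}
        try:
            with torch.no_grad():
                for lora, weight in loras:
                    for key, delta in lora.items():
                        param = unet.get_parameter(key)
                        originals.setdefault(key, param.detach().clone())
                        param += weight * delta.to(param.device, dtype=param.dtype)
            yield unet
        finally:
            with torch.no_grad():
                for key, original in originals.items():
                    unet.get_parameter(key).copy_(original)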
@@ -774,9 +783,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

-        vae_info = context.models.load(**self.vae.vae.model_dump())
+        vae_info = context.services.model_records.load_model(
+            **self.vae.vae.model_dump(),
+            context=context,
+        )

-        with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
+        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
            latents = latents.to(vae.device)
            if self.fp32:
                vae.to(dtype=torch.float32)
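
The fp32 branch exists because some Stable Diffusion VAEs overflow in float16
during decode, producing black or NaN images; casting the module and the
latents up to float32 trades memory for numerical stability. The bare
diffusers pattern, as a sketch (the repo id is only an example checkpoint):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # stand-in for real latents

    fp32 = True  # user-requested full-precision decode
    if fp32:
        vae.to(dtype=torch.float32)
        latents = latents.to(dtype=torch.float32)

    # Scale back out of the latent space before decoding, per the VAE's config.
    image = vae.decode(latents / vae.config.scaling_factor).sample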
@@ -995,7 +1008,10 @@ class ImageToLatentsInvocation(BaseInvocation):
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

-        vae_info = context.models.load(**self.vae.vae.model_dump())
+        vae_info = context.services.model_records.load_model(
+            **self.vae.vae.model_dump(),
+            context=context,
+        )

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3: