mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
committed by
psychedelicious
parent
7956602b19
commit
a23dedd2ee
@ -3,13 +3,15 @@
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import Iterator, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.adapter import T2IAdapter
|
||||
from diffusers.models.attention_processor import (
|
||||
@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
@ -46,9 +50,10 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.embeddings.lora import LoRAModelRaw
|
||||
from invokeai.backend.embeddings.model_patcher import ModelPatcher
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_manager import AnyModel, BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
@ -123,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image):
|
||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
# if shape is not None:
|
||||
@ -136,25 +141,25 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
else:
|
||||
image = None
|
||||
image_tensor = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.images.get_pil(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_records.load_model(
|
||||
if image_tensor is not None:
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
@ -177,7 +182,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_records.load_model(
|
||||
orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@ -188,7 +193,7 @@ def get_scheduler(
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
**scheduler_extra_config, # FIXME
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
@ -201,6 +206,7 @@ def get_scheduler(
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
assert isinstance(scheduler, Scheduler)
|
||||
return scheduler
|
||||
|
||||
|
||||
@ -284,7 +290,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
@ -298,9 +304,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@ -323,7 +329,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
@ -335,8 +341,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
@ -347,10 +353,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
@ -367,11 +373,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
) -> Optional[List[ControlNetData]]:
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||
@ -394,7 +400,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_records.load_model(
|
||||
context.services.model_manager.load.load_model_by_key(
|
||||
key=control_info.control_model.key,
|
||||
context=context,
|
||||
)
|
||||
@ -460,23 +466,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_records.load_model(
|
||||
context.services.model_manager.load.load_model_by_key(
|
||||
key=single_ip_adapter.ip_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_records.load_model(
|
||||
image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
|
||||
key=single_ip_adapter.image_encoder_model.key,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_images = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_images, list):
|
||||
single_ipa_images = [single_ipa_images]
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
|
||||
single_ipa_images = [
|
||||
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
|
||||
]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@ -520,21 +528,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_records.load_model(
|
||||
t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
|
||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
|
||||
)
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with t2i_adapter_model_info as t2i_adapter_model:
|
||||
@ -582,7 +588,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
timesteps = scheduler.timesteps.to(device=device)
|
||||
@ -594,11 +608,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
_timesteps = timesteps[:: scheduler.order]
|
||||
|
||||
# get start timestep index
|
||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
|
||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
|
||||
|
||||
# get end timestep index
|
||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
|
||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
|
||||
|
||||
# apply order to indexes
|
||||
@ -611,7 +625,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context: InvocationContext, latents):
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
@ -660,12 +676,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
context.util.sd_step_callback(state, self.unet.unet.base_model)
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
|
||||
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
self.dispatch_progress(context, source_node_id, state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_records.load_model(
|
||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
@ -673,7 +696,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_records.load_model(
|
||||
unet_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@ -783,7 +806,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_records.load_model(
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@ -961,8 +984,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info, upcast, tiled, image_tensor):
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@ -1008,7 +1032,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_records.load_model(
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@ -1026,14 +1050,19 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
||||
return latents
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -1066,7 +1095,12 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
@ -1099,12 +1133,16 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||
return v2_torch
|
||||
else:
|
||||
assert isinstance(v2, np.ndarray)
|
||||
return v2
|
||||
|
||||
# blend
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||
bl = slerp(self.alpha, latents_a, latents_b)
|
||||
assert isinstance(bl, torch.Tensor)
|
||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
@ -1197,15 +1235,19 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
|
||||
)
|
||||
|
||||
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
|
||||
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.services.model_manager.load.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
aspect = self.width / self.height
|
||||
dimension = 512
|
||||
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension = 768
|
||||
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif unet_config.base == BaseModelType.StableDiffusionXL:
|
||||
dimension = 1024
|
||||
dimension = dimension * self.multiplier
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
|
Reference in New Issue
Block a user