BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement new model loader and modify invocations and embeddings

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move the seamless and silence_warnings utils into a better location.
Lincoln Stein 2024-02-05 22:56:32 -05:00 committed by psychedelicious
parent 5745ce9c7d
commit 78ef946e01
31 changed files with 727 additions and 496 deletions
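
For node and client authors, the practical effect of the breaking change is that every model-identifying field collapses from a (model_name, base_model, model_type) triple into a single config-record key, as the field definitions in the diffs below show. A minimal before/after sketch using pydantic; the class names and values are illustrative, and the real fields use the BaseModelType/ModelType enums rather than plain strings:

from pydantic import BaseModel, Field

class OldStyleModelField(BaseModel):
    # Pre-change shape: models were addressed by name, base and type.
    model_name: str = Field(description="Name of the model")
    base_model: str = Field(description="Base model")
    model_type: str = Field(description="Model type")

class NewStyleModelField(BaseModel):
    # Post-change shape: a single key into the model config record store.
    key: str = Field(description="Model config record key")

old = OldStyleModelField(model_name="stable-diffusion-v1-5", base_model="sd-1", model_type="main")
new = NewStyleModelField(key="some-opaque-record-key")  # keys are assigned by the record store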

View File

@ -55,7 +55,7 @@ async def download(
) -> DownloadJob:
"""Download the source URL to the file or directory indicted in dest."""
queue = ApiDependencies.invoker.services.download_queue
return queue.download(source, dest, priority, access_token)
return queue.download(source, Path(dest), priority, access_token)
@download_queue_router.get(

View File

@ -1,9 +1,10 @@
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Tuple, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@ -12,18 +13,21 @@ from invokeai.app.invocations.fields import (
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.embeddings.model_patcher import ModelPatcher
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import ModelType
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -64,13 +68,22 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
tokenizer_info = context.services.model_records.load_model(
**self.clip.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_records.load_model(
**self.clip.text_encoder.model_dump(),
context=context,
)
def _lora_loader():
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
lora_info = context.services.model_records.load_model(
**lora.model_dump(exclude={"weight"}), context=context
)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
@ -80,24 +93,20 @@ class CompelInvocation(BaseInvocation):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.models.load(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model,
)
)
except ModelNotFoundException:
loaded_model = context.services.model_records.load_model(
**self.clip.text_encoder.model_dump(),
context=context,
).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer,
ti_manager,
),
@ -105,7 +114,7 @@ class CompelInvocation(BaseInvocation):
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@ -144,6 +153,8 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
"""Prompt processor for SDXL models."""
def run_clip_compel(
self,
context: InvocationContext,
@ -152,20 +163,27 @@ class SDXLPromptInvocationBase:
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
):
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.services.model_records.load_model(
**clip_field.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_records.load_model(
**clip_field.text_encoder.model_dump(),
context=context,
)
# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.context.model
cpu_text_encoder = text_encoder_info.model
assert isinstance(cpu_text_encoder, torch.nn.Module)
c = torch.zeros(
(
1,
cpu_text_encoder.config.max_position_embeddings,
cpu_text_encoder.config.hidden_size,
),
dtype=text_encoder_info.context.cache.precision,
dtype=cpu_text_encoder.dtype,
)
if get_pooled:
c_pooled = torch.zeros(
@ -176,10 +194,14 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled, None
def _lora_loader():
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
lora_info = context.services.model_records.load_model(
**lora.model_dump(exclude={"weight"}), context=context
)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
del lora_info
return
@ -189,24 +211,24 @@ class SDXLPromptInvocationBase:
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.models.load(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model,
)
)
except ModelNotFoundException:
ti_model = context.services.model_records.load_model_by_attr(
model_name=name,
base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion,
context=context,
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
logger.warning(f'trigger: "{trigger}" not found')
except ValueError:
logger.warning(f'trigger: "{trigger}" matches more than one similarly-named textual inversion model')
with (
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer,
ti_manager,
),
@ -214,7 +236,7 @@ class SDXLPromptInvocationBase:
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@ -332,6 +354,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
dim=1,
)
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -380,6 +403,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(

View File

@ -23,7 +23,7 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import (
FieldDescriptions,
@ -60,10 +60,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):

View File

@ -2,7 +2,8 @@ import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -18,18 +19,13 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel):
@ -46,16 +42,26 @@ class IPAdapterField(BaseModel):
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
def validate_ip_adapter_weight(cls, v: float) -> float:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
def validate_begin_end_step_percent(self) -> Self:
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
@ -84,33 +90,36 @@ class IPAdapterInvocation(BaseInvocation):
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
def validate_ip_adapter_weight(cls, v: float) -> float:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
def validate_begin_end_step_percent(self) -> Self:
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
# TODO (LS): Fix the issue above by:
# 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model.
# 2. Update probe.py to read `image_encoder.txt` and store it in the config.
# 3. Change below to get the image encoder from the configuration record.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.config.get().models_path, ip_adapter_info["path"])
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path)
)
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
image_encoder_models = context.services.model_records.search_by_attr(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,

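The HACK/TODO comments above describe the interim lookup flow: read the CLIP Vision encoder's repo id from the image_encoder.txt file stored next to the IP-Adapter weights, then search the model records by the trailing path component of that id to recover the encoder's key. A rough sketch of the first half of that flow; the directory name and repo id are hypothetical:

import os

def read_image_encoder_repo_id(ip_adapter_dir: str) -> str:
    # Mirrors get_ip_adapter_image_encoder_model_id() above: the repo id of the
    # associated encoder sits on the first line of image_encoder.txt.
    with open(os.path.join(ip_adapter_dir, "image_encoder.txt"), "r") as f:
        return f.readline().strip()

# A file containing "some-org/some-clip-vision-encoder" yields the search name
# "some-clip-vision-encoder", which the invocation resolves to a record key via
# model_records.search_by_attr(..., model_type=ModelType.CLIPVision).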
View File

@ -3,13 +3,13 @@
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
from typing import Iterator, List, Literal, Optional, Tuple, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import (
@ -46,14 +46,13 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.embeddings.model_patcher import ModelPatcher
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.model_manager import AnyModel, BaseModelType
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
IPAdapterData,
@ -149,7 +148,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image is not None:
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.services.model_records.load_model(
**self.vae.vae.model_dump(),
context=context,
)
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -175,7 +177,10 @@ def get_scheduler(
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
orig_scheduler_info = context.services.model_records.load_model(
**scheduler_info.model_dump(),
context=context,
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -389,10 +394,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.models.load(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context.services.model_records.load_model(
key=control_info.control_model.key,
context=context,
)
)
@ -456,17 +460,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
context.services.model_records.load_model(
key=single_ip_adapter.ip_adapter_model.key,
context=context,
)
)
image_encoder_model_info = context.models.load(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
image_encoder_model_info = context.services.model_records.load_model(
key=single_ip_adapter.image_encoder_model.key,
context=context,
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
@ -518,10 +520,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.models.load(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
t2i_adapter_model_info = context.services.model_records.load_model(
key=t2i_adapter_field.t2i_adapter_model.key,
context=context,
)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
@ -556,7 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
@ -662,22 +663,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState):
context.util.sd_step_callback(state, self.unet.unet.base_model)
def _lora_loader():
def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
lora_info = context.services.model_records.load_model(
**lora.model_dump(exclude={"weight"}),
context=context,
)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(**self.unet.unet.model_dump())
unet_info = context.services.model_records.load_model(
**self.unet.unet.model_dump(),
context=context,
)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):
assert isinstance(unet, torch.Tensor)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -774,9 +783,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.services.model_records.load_model(
**self.vae.vae.model_dump(),
context=context,
)
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.Tensor)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@ -995,7 +1008,10 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.services.model_records.load_model(
**self.vae.vae.model_dump(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@ -1,13 +1,13 @@
import copy
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -17,13 +17,9 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
key: str = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
model_config = ConfigDict(protected_namespaces=())
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
@ -52,7 +48,7 @@ class VaeField(BaseModel):
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field"""
"""Base class for invocations that output a UNet field."""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@ -81,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="LoRA model key")
@invocation(
@ -111,74 +100,31 @@ class MainModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
if not context.services.model_records.exists(key):
raise Exception(f"Unknown model {key}")
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.TextEncoder,
),
loras=[],
@ -186,9 +132,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Vae,
),
),
@ -226,21 +170,16 @@ class LoraLoaderInvocation(BaseInvocation):
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
lora_key = self.lora.key
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {lora_name}!")
if not context.services.model_records.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
@ -248,9 +187,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=lora_key,
submodel=None,
weight=self.weight,
)
@ -260,9 +197,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=lora_key,
submodel=None,
weight=self.weight,
)
@ -315,24 +250,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
lora_key = self.lora.key
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if not context.services.model_records.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
@ -340,9 +270,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=lora_key,
submodel=None,
weight=self.weight,
)
@ -352,9 +280,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=lora_key,
submodel=None,
weight=self.weight,
)
@ -364,9 +290,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=lora_key,
submodel=None,
weight=self.weight,
)
@ -378,10 +302,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
@ -395,25 +316,12 @@ class VaeLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
key = self.vae_model.key
if not context.models.exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VAEOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
)
if not context.services.model_records.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
@invocation_output("seamless_output")

View File

@ -1,7 +1,7 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -40,45 +40,31 @@ class SDXLModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_records.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder,
),
loras=[],
@ -86,15 +72,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder2,
),
loras=[],
@ -102,9 +84,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Vae,
),
),
@ -129,45 +109,31 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_records.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder2,
),
loras=[],
@ -175,9 +141,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Vae,
),
),

View File

@ -1,6 +1,6 @@
from typing import Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -12,14 +12,10 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESI
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType
class T2IAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the T2I-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):

View File

@ -11,8 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import AnyModelConfig
class EventServiceBase:
@ -171,10 +170,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_config: AnyModelConfig,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
@ -184,10 +180,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"model_config": model_config.model_dump(),
},
)
@ -197,11 +190,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
loaded_model_info: LoadedModelInfo,
model_config: AnyModelConfig,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
@ -211,13 +200,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"hash": loaded_model_info.hash,
"location": str(loaded_model_info.location),
"precision": str(loaded_model_info.precision),
"model_config": model_config.model_dump(),
},
)
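For event consumers, the payload of the model-load events changes shape: the individual identifying fields (and the hash/location/precision block on the completed event) are replaced by a single dump of the model's config record. An illustrative comparison with made-up values; the keys inside "model_config" depend on the AnyModelConfig variant being loaded:

# Before this commit: identifying fields sent individually.
old_event_payload = {
    "graph_execution_state_id": "some-graph-id",
    "model_name": "stable-diffusion-v1-5",
    "base_model": "sd-1",
    "model_type": "main",
    "submodel": "unet",
}

# After this commit: one nested config dump.
new_event_payload = {
    "graph_execution_state_id": "some-graph-id",
    "model_config": {"key": "some-opaque-record-key", "name": "stable-diffusion-v1-5", "type": "main"},
}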

View File

@ -2,6 +2,7 @@ import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
import psutil
import torch
@ -10,7 +11,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase
from .invocation_stats_common import (
@ -41,7 +42,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._invoker = invoker
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
services = self._invoker.services
if services.model_records is None or services.model_records.loader is None:
yield None
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats()
@ -55,8 +59,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
start_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
if self._invoker.services.model_manager:
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records
assert services.model_records.loader is not None
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
try:
# Let the invocation run.
@ -73,7 +79,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def _prune_stale_stats(self):
def _prune_stale_stats(self) -> None:
"""Check all graphs being tracked and prune any that have completed/errored.
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so

View File

@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -19,6 +20,7 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -110,12 +112,45 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
def load_model(
self,
key: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline models), the submodel to fetch
:param submodel: For main (pipeline models), the submodel to fetch
:param context: Invocation context, used for event issuing.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
@ -166,7 +201,7 @@ class ModelRecordServiceBase(ABC):
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
Return True if a model with the indicated key exists in the database.
:param key: Unique key for the model to be deleted
"""
@ -209,6 +244,12 @@ class ModelRecordServiceBase(ABC):
"""
pass
@property
@abstractmethod
def loader(self) -> Optional[AnyModelLoader]:
"""Return the model loader used by this instance."""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_attr()
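
Taken together with the call sites earlier in the commit, the intended usage of the two loaders from inside an invocation looks roughly like the sketch below. It is a sketch rather than runnable code: context and self are whatever the invocation already has in scope, and the enum member names are indicative.

# Key-based load (the common path; passing context is optional and only enables load events).
unet_info = context.services.model_records.load_model(
    key=self.unet.unet.key,
    submodel=SubModelType.UNet,
    context=context,
)
with unet_info as unet:
    ...  # the model is ready to use once the context is entered

# Name-based load, kept for compatibility with the old get_model() API. Raises
# UnknownModelException if nothing matches and ValueError if the name is ambiguous.
ti_info = context.services.model_records.load_model_by_attr(
    model_name="easynegative",
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.TextualInversion,
    context=context,
)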

View File

@ -46,6 +46,8 @@ from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@ -88,6 +90,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Return the underlying database."""
return self._db
@property
def loader(self) -> Optional[AnyModelLoader]:
"""Return the model loader used by this instance."""
return self._loader
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -213,20 +220,73 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
def load_model(
self,
key: str,
submodel: Optional[SubModelType],
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline models), the submodel to fetch.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
# we can emit model loading events if we are executing with access to the invocation context
model_config = self.get_model(key)
return self._loader.load_model(model_config, submodel_type)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._loader.load_model(model_config, submodel)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model(configs[0].key, submodel)
def exists(self, key: str) -> bool:
"""
@ -416,3 +476,29 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

View File

@ -0,0 +1,4 @@
"""Initialization file for invokeai.backend.embeddings modules."""
# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]

View File

@ -0,0 +1,12 @@
"""Base class for LoRA and Textual Inversion models.
The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.
The use of "Raw" here is a historical artifact, and carried forward in
order to avoid confusion.
"""
class EmbeddingModelRaw:
"""Base class for LoRA and Textual Inversion models."""

View File

@ -11,6 +11,8 @@ from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from .embedding_base import EmbeddingModelRaw
class LoRALayerBase:
# rank: Optional[int]
@ -317,7 +319,7 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
@ -367,7 +369,7 @@ AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
@ -471,16 +473,16 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

View File

@ -4,22 +4,20 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
from safetensors.torch import load_file
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from typing_extensions import Self
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
"""
loras = [
@ -67,7 +65,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -76,8 +74,8 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -87,7 +85,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@ -97,7 +95,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@ -105,10 +103,10 @@ class ModelPatcher:
@contextmanager
def apply_lora(
cls,
model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
loras: List[Tuple[LoRAModelRaw, float]],
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
original_weights = {}
try:
with torch.no_grad():
@ -125,6 +123,7 @@ class ModelPatcher:
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
@ -170,8 +169,8 @@ class ModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModel]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
init_tokens_count = None
new_tokens_added = None
@ -201,7 +200,7 @@ class ModelPatcher:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor:
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
@ -229,6 +228,7 @@ class ModelPatcher:
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
assert isinstance(ti, TextualInversionModelRaw)
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
ti_tokens = []
@ -267,7 +267,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
) -> Generator[None, None, None]:
) -> None:
skipped_layers = []
try:
for _i in range(clip_skip):
@ -285,7 +285,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> Generator[None, None, None]:
) -> None:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
@ -301,94 +301,6 @@ class ModelPatcher:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
@ -396,7 +308,7 @@ class ONNXModelPatcher:
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -406,7 +318,7 @@ class ONNXModelPatcher:
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -419,7 +331,7 @@ class ONNXModelPatcher:
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
@ -506,7 +418,7 @@ class ONNXModelPatcher:
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):

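The hunks above only adjust the annotations on the ONNX LoRA patchers; for orientation, a minimal usage sketch follows. The method names `apply_lora_unet` and `apply_lora_text_encoder` are assumptions drawn from the analogous torch ModelPatcher API, since the `def` lines fall outside the hunks shown.

# Sketch only: patch both ONNX submodels for the duration of a generation.
# `unet`, `text_encoder`, and `loras` are assumed to come from the loaded model context.
from contextlib import ExitStack

with ExitStack() as stack:
    stack.enter_context(ONNXModelPatcher.apply_lora_unet(unet, loras))
    stack.enter_context(ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras))
    # ... run the ONNX pipeline; the patched weights are restored when the block exits ...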
View File

@@ -0,0 +1,100 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
from .embedding_base import EmbeddingModelRaw
class TextualInversionModelRaw(EmbeddingModelRaw):
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warning: Embedding "{file_path.name}" contains multiple tokens, which is not supported.'
" The first token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is to compensate for compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids

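A minimal sketch of how the new wrapper and manager fit together; the file path and model id are illustrative, and the pad-token registration shown here is an assumption (in InvokeAI it is performed by the model patcher when the embedding is applied).

# Sketch: load a TI embedding and expand its trigger token to cover all vectors.
from pathlib import Path
from transformers import CLIPTokenizer

ti = TextualInversionModelRaw.from_checkpoint(Path("embeddings/easynegative.safetensors"))
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

trigger = "easynegative"
tokenizer.add_tokens([trigger])
trigger_id = tokenizer.convert_tokens_to_ids(trigger)

manager = TextualInversionManager(tokenizer)
# If the embedding spans several vectors, pad the trigger with extra token ids (illustrative).
n_vectors = ti.embedding.shape[0]
manager.pad_tokens[trigger_id] = [trigger_id] * (n_vectors - 1)

token_ids = tokenizer.encode("a photo, easynegative", add_special_tokens=False)
expanded = manager.expand_textual_inversion_token_ids_if_necessary(token_ids)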
View File

@@ -241,10 +241,11 @@ class InstallHelper(object):
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
repo_id = match.group(1)
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
return HFModelSource(
repo_id=repo_id,
access_token=HfFolder.get_token(),
subfolder=model_info.subfolder,
subfolder=subfolder,
variant=repo_variant,
)
if re.match(r"^(http|https):", model_path_id_or_url):

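For reference, a rough sketch of the strings the pattern above accepts; the `variants` alternation below is an assumption standing in for the ModelRepoVariant values built earlier in the method.

import re

variants = "fp16|fp32|onnx|default"  # assumption; InstallHelper builds this from ModelRepoVariant
pattern = re.compile(f"^([^/]+/[^/]+?)(?::({variants}))?$")

assert pattern.match("stabilityai/sdxl-turbo").group(1) == "stabilityai/sdxl-turbo"
m = pattern.match("stabilityai/sdxl-turbo:fp16")
assert m and m.group(1) == "stabilityai/sdxl-turbo" and m.group(2) == "fp16"
assert pattern.match("https://example.com/model.safetensors") is None  # falls through to the URL branch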
View File

@@ -30,8 +30,11 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from ..embeddings.embedding_base import EmbeddingModelRaw
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
@@ -299,7 +302,7 @@ AnyModelConfig = Union[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown

View File

@@ -18,8 +18,8 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger

View File

@@ -19,7 +19,7 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@@ -71,7 +71,7 @@ class ModelLoader(ModelLoaderBase):
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)

View File

@@ -1,4 +1,6 @@
"""Init file for ModelCache."""
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401
__all__ = ["ModelCacheBase", "ModelCache"]
__all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

View File

@@ -8,13 +8,13 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import Logger
from typing import Generic, Optional, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.config import AnyModel, SubModelType
class ModelLockerBase(ABC):
@@ -65,6 +65,19 @@ class CacheRecord(Generic[T]):
return self._locks > 0
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@@ -98,10 +111,22 @@ class ModelCacheBase(ABC, Generic[T]):
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None:
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> CacheStats:
"""Return collected CacheStats object."""
pass
@stats.setter
@abstractmethod
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
pass
@property
@abstractmethod
def logger(self) -> Logger:

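Stats collection is now opt-in: the cache's `stats` attribute stays `None` until a collector assigns a `CacheStats` instance. A minimal sketch of that flow (the consuming function name is an assumption):

# Sketch: enable collection for one session, then summarize.
def collect_session_stats(cache: ModelCacheBase) -> CacheStats:
    stats = CacheStats()
    cache.stats = stats  # collection is disabled while this is None
    # ... run a graph / load and reuse some models here ...
    total = stats.hits + stats.misses
    hit_rate = stats.hits / total if total else 0.0
    print(f"hits={stats.hits} misses={stats.misses} hit_rate={hit_rate:.2f} in_cache={stats.in_cache}")
    return stats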
View File

@@ -24,19 +24,17 @@ import math
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
if choose_torch_device() == torch.device("mps"):
@@ -56,20 +54,6 @@ GIG = 1073741824
MB = 2**20
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@@ -110,7 +94,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
# used for stats collection
self.stats = CacheStats()
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
@@ -140,6 +124,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the cap on cache size."""
return self._max_cache_size
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
@@ -189,21 +183,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
self.stats.hits += 1
if self.stats:
self.stats.hits += 1
else:
self.stats.misses += 1
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):

View File

@@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@@ -4,3 +4,12 @@ Initialization file for the invokeai.backend.stable_diffusion package
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .seamless import set_seamless # noqa: F401
__all__ = [
"PipelineIntermediateState",
"StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent",
"AttentionMapSaver",
"set_seamless",
]

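With the re-export in place, callers can import the utility from the package root rather than the new module path; a small check of the assumed behavior:

# Both paths should resolve to the same context manager after this change.
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.seamless import set_seamless as set_seamless_direct

assert set_seamless is set_seamless_direct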
View File

@@ -0,0 +1,102 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import List, Union
import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(module, "asymmetric_padding_mode"):
del module.asymmetric_padding_mode
if hasattr(module, "asymmetric_padding"):
del module.asymmetric_padding

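A minimal usage sketch for the context manager above; `pipe` is assumed to be a loaded diffusers pipeline exposing `unet` and `vae` submodels, and the prompt is illustrative.

# Sketch: enable circular padding on the x axis for horizontally tileable output.
with set_seamless(pipe.unet, ["x"]), set_seamless(pipe.vae, ["x"]):
    image = pipe(prompt="a tiling brick texture").images[0]
# On exit the original Conv2d._conv_forward methods and padding attributes are restored.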
View File

@@ -0,0 +1,28 @@
"""Context class to silence transformers and diffusers warnings."""
import warnings
from typing import Any
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
"""Use in context to temporarily turn off warnings from transformers & diffusers modules.
with SilenceWarnings():
# do something
"""
def __init__(self) -> None:
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self) -> None:
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, *args: Any) -> None:
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

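Typical use, per the docstring; the pipeline class and model id below are examples only, not part of this module.

from diffusers import AutoPipelineForText2Image

# Suppress transformers/diffusers warnings only while the model loads.
with SilenceWarnings():
    pipe = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
# Verbosity levels and the warnings filter are restored here.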
View File

@@ -23,7 +23,7 @@ import torch
from npyscreen import widget
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
from invokeai.backend.model_manager import ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
@@ -499,7 +499,7 @@ class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
)
def list_models(installer: ModelInstallService, model_type: ModelType):
def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
"""Print out all models of type model_type."""
models = installer.record_store.search_by_attr(model_type=model_type)
print(f"Installed models of type `{model_type}`:")
@@ -527,7 +527,9 @@ def select_and_download_models(opt: Namespace) -> None:
install_helper.add_or_delete(selections)
elif opt.default_only:
selections = InstallSelections(install_models=[install_helper.default_model()])
default_model = install_helper.default_model()
assert default_model is not None
selections = InstallSelections(install_models=[default_model])
install_helper.add_or_delete(selections)
elif opt.yes_to_all:

View File

@@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]:
"text_encoder/model.onnx",
"text_encoder_2/config.json",
"text_encoder_2/model.onnx",
"text_encoder_2/model.onnx_data",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
@@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]:
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/model.onnx",
"unet/model.onnx_data",
"vae_decoder/config.json",
"vae_decoder/model.onnx",
"vae_encoder/config.json",

View File

@@ -2,7 +2,7 @@ from pathlib import Path
import pytest
from invokeai.backend import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.probe import VaeFolderProbe
@@ -21,10 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
base_type = probe.get_base_type()
assert base_type == expected_type
repo_variant = probe.get_repo_variant()
assert repo_variant == "default"
assert repo_variant == ModelRepoVariant.DEFAULT
def test_repo_variant(datadir: Path):
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
repo_variant = probe.get_repo_variant()
assert repo_variant == "fp16"
assert repo_variant == ModelRepoVariant.FP16