Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
This commit is contained in:
Lincoln Stein
2024-02-17 11:45:32 -05:00
committed by psychedelicious
parent ba1f8878dd
commit 2ad0752582
89 changed files with 355 additions and 1609 deletions

View File

@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from .resampler import Resampler
from ..raw_model import RawModel
class ImageProjModel(torch.nn.Module):
@ -91,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
return clip_extra_context_tokens
class IPAdapter:
class IPAdapter(RawModel):
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
def __init__(

View File

@ -10,8 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from .embedding_base import EmbeddingModelRaw
from .raw_model import RawModel
class LoRALayerBase:
@ -367,9 +366,7 @@ class IA3Layer(LoRALayerBase):
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module):
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]

View File

@ -28,12 +28,11 @@ from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from ..raw_model import RawModel
from ..embeddings.embedding_base import EmbeddingModelRaw
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
class InvalidModelConfigException(Exception):

View File

@ -5,7 +5,7 @@ import psutil
import torch
from typing_extensions import Self
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
from ..util.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return "\n" + msg if len(msg) > 0 else msg
return msg

View File

@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@ -8,9 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -55,7 +53,6 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
}
class ProbeBase(object):
"""Base class for probes."""
@ -653,8 +650,8 @@ class LoRAFolderProbe(FolderProbeBase):
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> IPAdapterModelFormat:
return IPAdapterModelFormat.InvokeAI.value
def get_format(self) -> ModelFormat:
return ModelFormat.InvokeAI
def get_base_type(self) -> BaseModelType:
model_file = self.model_path / "ip_adapter.bin"

View File

@ -0,0 +1,75 @@
import ctypes
class Struct_mallinfo2(ctypes.Structure):
"""A ctypes Structure that matches the libc mallinfo2 struct.
Docs:
- https://man7.org/linux/man-pages/man3/mallinfo.3.html
- https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html
struct mallinfo2 {
size_t arena; /* Non-mmapped space allocated (bytes) */
size_t ordblks; /* Number of free chunks */
size_t smblks; /* Number of free fastbin blocks */
size_t hblks; /* Number of mmapped regions */
size_t hblkhd; /* Space allocated in mmapped regions (bytes) */
size_t usmblks; /* See below */
size_t fsmblks; /* Space in freed fastbin blocks (bytes) */
size_t uordblks; /* Total allocated space (bytes) */
size_t fordblks; /* Total free space (bytes) */
size_t keepcost; /* Top-most, releasable space (bytes) */
};
"""
_fields_ = [
("arena", ctypes.c_size_t),
("ordblks", ctypes.c_size_t),
("smblks", ctypes.c_size_t),
("hblks", ctypes.c_size_t),
("hblkhd", ctypes.c_size_t),
("usmblks", ctypes.c_size_t),
("fsmblks", ctypes.c_size_t),
("uordblks", ctypes.c_size_t),
("fordblks", ctypes.c_size_t),
("keepcost", ctypes.c_size_t),
]
def __str__(self):
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
s += (
f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
" (GB)\n"
)
s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
return s
class LibcUtil:
"""A utility class for interacting with the C Standard Library (`libc`) via ctypes.
Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
this shared library is not available.
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
"""Calls `libc` `mallinfo2`.
Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()

View File

@ -0,0 +1,129 @@
"""Utilities for parsing model files, used mostly by probe.py"""
import json
import torch
from typing import Union
from pathlib import Path
from picklescan.scanner import scan_file_path
def _fast_safetensors_reader(path: str):
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor, checkpoint) -> int:
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (
"time_emb_proj.lora_down" in key
): # recognizes format at https://civitai.com/models/224641
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length

View File

@ -8,6 +8,7 @@ import numpy as np
import onnx
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from ..raw_model import RawModel
ONNX_WEIGHTS_NAME = "model.onnx"

View File

@ -0,0 +1,14 @@
"""Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
class RawModel:
"""Base class for 'Raw' model wrappers."""

View File

@ -8,11 +8,9 @@ from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
from .raw_model import RawModel
from .embedding_base import EmbeddingModelRaw
class TextualInversionModelRaw(EmbeddingModelRaw):
class TextualInversionModelRaw(RawModel):
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models

View File

@ -5,10 +5,9 @@ from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
@pytest.fixture(scope="session")
@ -16,31 +15,20 @@ def torch_device():
return "cuda" if torch.cuda.is_available() else "cpu"
@pytest.fixture(scope="module")
def model_installer():
"""A global ModelInstall pytest fixture to be used by many tests."""
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
# a singleton.
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
def install_and_load_model(
model_installer: ModelInstall,
model_manager: ModelManagerServiceBase,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModelInfo:
"""Install a model if it is not already installed, then get the LoadedModelInfo for that model.
) -> LoadedModel:
"""Install a model if it is not already installed, then get the LoadedModel for that model.
This is intended as a utility function for tests.
Args:
model_installer (ModelInstall): The model installer.
mm2_model_manager (ModelManagerServiceBase): The model manager
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
is not already installed.
model_name (str): The model name, forwarded to ModelManager.get_model(...).
@ -51,16 +39,23 @@ def install_and_load_model(
Returns:
LoadedModelInfo
"""
# If the requested model is already installed, return its LoadedModelInfo.
with contextlib.suppress(ModelNotFoundException):
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
# If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call
loaded_model: LoadedModel = model_manager.load.load_model_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type
)
return loaded_model
# Install the requested model.
model_installer.heuristic_import(model_path_id_or_url)
job = model_manager.install.heuristic_import(model_path_id_or_url)
model_manager.install.wait_for_job(job, timeout=10)
assert job.complete
try:
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
except ModelNotFoundException as e:
loaded_model = model_manager.load.load_model_by_config(job.config_out)
return loaded_model
except UnknownModelException as e:
raise Exception(
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
f" the installation id ('{model_path_id_or_url}'). Error: {e}"