Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to the original model manager.
- Caught and fixed most dangling references (still checking).
- Rename the lora, textual_inversion and model_patcher modules.
- Introduce a RawModel base class to simplify the Union returned by the model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
Committed by psychedelicious (parent ba1f8878dd, commit 2ad0752582)
@@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from .resampler import Resampler
+from ..raw_model import RawModel


class ImageProjModel(torch.nn.Module):
@@ -91,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
        return clip_extra_context_tokens


-class IPAdapter:
+class IPAdapter(RawModel):
    """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""

    def __init__(
@@ -10,8 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self

from invokeai.backend.model_manager import BaseModelType
-from .embedding_base import EmbeddingModelRaw
+from .raw_model import RawModel


class LoRALayerBase:
@@ -367,9 +366,7 @@ class IA3Layer(LoRALayerBase):

AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]


# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
-class LoRAModelRaw(EmbeddingModelRaw):  # (torch.nn.Module):
+class LoRAModelRaw(RawModel):  # (torch.nn.Module):
    _name: str
    layers: Dict[str, AnyLoRALayer]
@@ -28,12 +28,11 @@ from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict

-from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
-from ..embeddings.embedding_base import EmbeddingModelRaw
-from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
+from ..raw_model import RawModel

-AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
+# ModelMixin is the base class for all diffusers and transformers models
+# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]


class InvalidModelConfigException(Exception):
@@ -5,7 +5,7 @@ import psutil
import torch
from typing_extensions import Self

-from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
+from ..util.libc_util import LibcUtil, Struct_mallinfo2

GB = 2**30  # 1 GB

@@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]
    if snapshot_1.vram is not None and snapshot_2.vram is not None:
        msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)

-    return "\n" + msg if len(msg) > 0 else msg
+    return msg
@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional, Tuple

from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.embeddings.lora import LoRAModelRaw
+from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
@@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Tuple

-from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
@@ -8,9 +8,7 @@ import torch
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
-from invokeai.backend.model_management.models.base import read_checkpoint_meta
-from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
-from invokeai.backend.model_management.util import lora_token_vector_length
+from .util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.util.util import SilenceWarnings

from .config import (

@@ -55,7 +53,6 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
    },
}


class ProbeBase(object):
    """Base class for probes."""
@@ -653,8 +650,8 @@ class LoRAFolderProbe(FolderProbeBase):


class IPAdapterFolderProbe(FolderProbeBase):
-    def get_format(self) -> IPAdapterModelFormat:
-        return IPAdapterModelFormat.InvokeAI.value
+    def get_format(self) -> ModelFormat:
+        return ModelFormat.InvokeAI

    def get_base_type(self) -> BaseModelType:
        model_file = self.model_path / "ip_adapter.bin"
invokeai/backend/model_manager/util/libc_util.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import ctypes


class Struct_mallinfo2(ctypes.Structure):
    """A ctypes Structure that matches the libc mallinfo2 struct.

    Docs:
    - https://man7.org/linux/man-pages/man3/mallinfo.3.html
    - https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html

    struct mallinfo2 {
        size_t arena;     /* Non-mmapped space allocated (bytes) */
        size_t ordblks;   /* Number of free chunks */
        size_t smblks;    /* Number of free fastbin blocks */
        size_t hblks;     /* Number of mmapped regions */
        size_t hblkhd;    /* Space allocated in mmapped regions (bytes) */
        size_t usmblks;   /* See below */
        size_t fsmblks;   /* Space in freed fastbin blocks (bytes) */
        size_t uordblks;  /* Total allocated space (bytes) */
        size_t fordblks;  /* Total free space (bytes) */
        size_t keepcost;  /* Top-most, releasable space (bytes) */
    };
    """

    _fields_ = [
        ("arena", ctypes.c_size_t),
        ("ordblks", ctypes.c_size_t),
        ("smblks", ctypes.c_size_t),
        ("hblks", ctypes.c_size_t),
        ("hblkhd", ctypes.c_size_t),
        ("usmblks", ctypes.c_size_t),
        ("fsmblks", ctypes.c_size_t),
        ("uordblks", ctypes.c_size_t),
        ("fordblks", ctypes.c_size_t),
        ("keepcost", ctypes.c_size_t),
    ]

    def __str__(self):
        s = ""
        s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
        s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
        s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
        s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
        s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
        s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
        s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
        s += (
            f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
            " (GB)\n"
        )
        s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
        s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
        return s


class LibcUtil:
    """A utility class for interacting with the C Standard Library (`libc`) via ctypes.

    Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
    this shared library is not available.

    TODO: Improve cross-OS compatibility of this class.
    """

    def __init__(self):
        self._libc = ctypes.cdll.LoadLibrary("libc.so.6")

    def mallinfo2(self) -> Struct_mallinfo2:
        """Calls `libc` `mallinfo2`.

        Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
        """
        mallinfo2 = self._libc.mallinfo2
        mallinfo2.restype = Struct_mallinfo2
        return mallinfo2()
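For orientation, a minimal usage sketch of the new utility (not part of the commit; assumes a Linux host where libc.so.6 is present):

# Hypothetical usage sketch, not part of this diff.
from invokeai.backend.model_manager.util.libc_util import LibcUtil

try:
    info = LibcUtil().mallinfo2()  # LoadLibrary raises OSError if libc.so.6 is unavailable
    print(info)  # Struct_mallinfo2.__str__ prints a per-field summary in GB
except OSError:
    print("libc.so.6 not available on this platform")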
invokeai/backend/model_manager/util/model_util.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""Utilities for parsing model files, used mostly by probe.py"""

import json
import torch
import safetensors.torch  # needed by read_checkpoint_meta's fallback to safetensors.torch.load_file
from typing import Union
from pathlib import Path
from picklescan.scanner import scan_file_path
def _fast_safetensors_reader(path: str):
    checkpoint = {}
    device = torch.device("meta")
    with open(path, "rb") as f:
        definition_len = int.from_bytes(f.read(8), "little")
        definition_json = f.read(definition_len)
        definition = json.loads(definition_json)

        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
            "pt",
            "torch",
            "pytorch",
        }:
            raise Exception("Supported only pytorch safetensors files")
        definition.pop("__metadata__", None)

        for key, info in definition.items():
            dtype = {
                "I8": torch.int8,
                "I16": torch.int16,
                "I32": torch.int32,
                "I64": torch.int64,
                "F16": torch.float16,
                "F32": torch.float32,
                "F64": torch.float64,
            }[info["dtype"]]

            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)

    return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
    if str(path).endswith(".safetensors"):
        try:
            checkpoint = _fast_safetensors_reader(path)
        except Exception:
            # TODO: create issue for support "meta"?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    else:
        if scan:
            scan_result = scan_file_path(path)
            if scan_result.infected_files != 0:
                raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
        checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint
def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key: str, tensor, checkpoint) -> int:
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_unet_") and (
            "time_emb_proj.lora_down" in key
        ):  # recognizes format at https://civitai.com/models/224641
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
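As a rough illustration (not part of the commit), the relocated helpers might be used together like this; the file name below is a placeholder:

# Hypothetical usage sketch; "some_lora.safetensors" is a placeholder path.
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta

state_dict = read_checkpoint_meta("some_lora.safetensors")  # tensors land on the "meta" device, so no weights are read
token_dim = lora_token_vector_length(state_dict)  # e.g. 768 (SD-1), 1024 (SD-2) or 2048 (SDXL); None if undetected
print(token_dim)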
@@ -8,6 +8,7 @@ import numpy as np
import onnx
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+from ..raw_model import RawModel

ONNX_WEIGHTS_NAME = "model.onnx"
invokeai/backend/raw_model.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Base class for 'Raw' models.

The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.

The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""


class RawModel:
    """Base class for 'Raw' model wrappers."""
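To illustrate the type-checking role described in the docstring, a small sketch (not part of the commit; the describe() helper is purely illustrative):

# Hypothetical sketch: RawModel as a common base for isinstance checks.
from invokeai.backend.model_manager import AnyModel  # Union[ModelMixin, RawModel, torch.nn.Module]
from invokeai.backend.raw_model import RawModel


def describe(model: AnyModel) -> str:
    # Per this commit, LoRAModelRaw, TextualInversionModelRaw and IPAdapter subclass RawModel
    # (and the ONNX runtime wrapper is described as a RawModel too), so one isinstance check
    # covers every InvokeAI wrapper type.
    if isinstance(model, RawModel):
        return "InvokeAI 'raw' wrapper"
    return "diffusers/transformers model or plain torch.nn.Module"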
@@ -8,11 +8,9 @@ from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
+from .raw_model import RawModel

-from .embedding_base import EmbeddingModelRaw


-class TextualInversionModelRaw(EmbeddingModelRaw):
+class TextualInversionModelRaw(RawModel):
    embedding: torch.Tensor  # [n, 768]|[n, 1280]
    embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models
@@ -5,10 +5,9 @@ from typing import Optional, Union
import pytest
import torch

-from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.backend.install.model_install_backend import ModelInstall
-from invokeai.backend.model_management.model_manager import LoadedModelInfo
-from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.app.services.model_records import UnknownModelException
+from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType


@pytest.fixture(scope="session")
@@ -16,31 +15,20 @@ def torch_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


-@pytest.fixture(scope="module")
-def model_installer():
-    """A global ModelInstall pytest fixture to be used by many tests."""
-    # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
-    # between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
-    # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
-    # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
-    # a singleton.
-    return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))


def install_and_load_model(
-    model_installer: ModelInstall,
+    model_manager: ModelManagerServiceBase,
    model_path_id_or_url: Union[str, Path],
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
-) -> LoadedModelInfo:
-    """Install a model if it is not already installed, then get the LoadedModelInfo for that model.
+) -> LoadedModel:
+    """Install a model if it is not already installed, then get the LoadedModel for that model.

    This is intended as a utility function for tests.

    Args:
-        model_installer (ModelInstall): The model installer.
+        mm2_model_manager (ModelManagerServiceBase): The model manager
        model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
            is not already installed.
        model_name (str): The model name, forwarded to ModelManager.get_model(...).
@@ -51,16 +39,23 @@ def install_and_load_model(
    Returns:
        LoadedModelInfo
    """
-    # If the requested model is already installed, return its LoadedModelInfo.
-    with contextlib.suppress(ModelNotFoundException):
-        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
+    # If the requested model is already installed, return its LoadedModel
+    with contextlib.suppress(UnknownModelException):
+        # TODO: Replace with wrapper call
+        loaded_model: LoadedModel = model_manager.load.load_model_by_attr(
+            model_name=model_name, base_model=base_model, model_type=model_type
+        )
+        return loaded_model

    # Install the requested model.
-    model_installer.heuristic_import(model_path_id_or_url)
+    job = model_manager.install.heuristic_import(model_path_id_or_url)
+    model_manager.install.wait_for_job(job, timeout=10)
+    assert job.complete

    try:
-        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
-    except ModelNotFoundException as e:
+        loaded_model = model_manager.load.load_model_by_config(job.config_out)
+        return loaded_model
+    except UnknownModelException as e:
        raise Exception(
            "Failed to get model info after installing it. There could be a mismatch between the requested model and"
            f" the installation id ('{model_path_id_or_url}'). Error: {e}"