Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
This commit is contained in:
Lincoln Stein
2024-02-17 11:45:32 -05:00
committed by psychedelicious
parent 996eb96b4e
commit 5d612ec095
89 changed files with 355 additions and 1609 deletions

View File

@ -28,12 +28,11 @@ from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from ..raw_model import RawModel
from ..embeddings.embedding_base import EmbeddingModelRaw
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
class InvalidModelConfigException(Exception):

View File

@ -0,0 +1,75 @@
import ctypes
class Struct_mallinfo2(ctypes.Structure):
"""A ctypes Structure that matches the libc mallinfo2 struct.
Docs:
- https://man7.org/linux/man-pages/man3/mallinfo.3.html
- https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html
struct mallinfo2 {
size_t arena; /* Non-mmapped space allocated (bytes) */
size_t ordblks; /* Number of free chunks */
size_t smblks; /* Number of free fastbin blocks */
size_t hblks; /* Number of mmapped regions */
size_t hblkhd; /* Space allocated in mmapped regions (bytes) */
size_t usmblks; /* See below */
size_t fsmblks; /* Space in freed fastbin blocks (bytes) */
size_t uordblks; /* Total allocated space (bytes) */
size_t fordblks; /* Total free space (bytes) */
size_t keepcost; /* Top-most, releasable space (bytes) */
};
"""
_fields_ = [
("arena", ctypes.c_size_t),
("ordblks", ctypes.c_size_t),
("smblks", ctypes.c_size_t),
("hblks", ctypes.c_size_t),
("hblkhd", ctypes.c_size_t),
("usmblks", ctypes.c_size_t),
("fsmblks", ctypes.c_size_t),
("uordblks", ctypes.c_size_t),
("fordblks", ctypes.c_size_t),
("keepcost", ctypes.c_size_t),
]
def __str__(self):
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
s += (
f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
" (GB)\n"
)
s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
return s
class LibcUtil:
"""A utility class for interacting with the C Standard Library (`libc`) via ctypes.
Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
this shared library is not available.
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
"""Calls `libc` `mallinfo2`.
Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()

View File

@ -5,7 +5,7 @@ import psutil
import torch
from typing_extensions import Self
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
from ..util.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return "\n" + msg if len(msg) > 0 else msg
return msg

View File

@ -7,7 +7,7 @@ from pathlib import Path
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@ -8,9 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -55,7 +53,6 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
}
class ProbeBase(object):
"""Base class for probes."""
@ -653,8 +650,8 @@ class LoRAFolderProbe(FolderProbeBase):
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> IPAdapterModelFormat:
return IPAdapterModelFormat.InvokeAI.value
def get_format(self) -> ModelFormat:
return ModelFormat.InvokeAI
def get_base_type(self) -> BaseModelType:
model_file = self.model_path / "ip_adapter.bin"

View File

@ -0,0 +1,75 @@
import ctypes
class Struct_mallinfo2(ctypes.Structure):
"""A ctypes Structure that matches the libc mallinfo2 struct.
Docs:
- https://man7.org/linux/man-pages/man3/mallinfo.3.html
- https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html
struct mallinfo2 {
size_t arena; /* Non-mmapped space allocated (bytes) */
size_t ordblks; /* Number of free chunks */
size_t smblks; /* Number of free fastbin blocks */
size_t hblks; /* Number of mmapped regions */
size_t hblkhd; /* Space allocated in mmapped regions (bytes) */
size_t usmblks; /* See below */
size_t fsmblks; /* Space in freed fastbin blocks (bytes) */
size_t uordblks; /* Total allocated space (bytes) */
size_t fordblks; /* Total free space (bytes) */
size_t keepcost; /* Top-most, releasable space (bytes) */
};
"""
_fields_ = [
("arena", ctypes.c_size_t),
("ordblks", ctypes.c_size_t),
("smblks", ctypes.c_size_t),
("hblks", ctypes.c_size_t),
("hblkhd", ctypes.c_size_t),
("usmblks", ctypes.c_size_t),
("fsmblks", ctypes.c_size_t),
("uordblks", ctypes.c_size_t),
("fordblks", ctypes.c_size_t),
("keepcost", ctypes.c_size_t),
]
def __str__(self):
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
s += (
f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
" (GB)\n"
)
s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
return s
class LibcUtil:
"""A utility class for interacting with the C Standard Library (`libc`) via ctypes.
Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
this shared library is not available.
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
"""Calls `libc` `mallinfo2`.
Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()

View File

@ -0,0 +1,129 @@
"""Utilities for parsing model files, used mostly by probe.py"""
import json
import torch
from typing import Union
from pathlib import Path
from picklescan.scanner import scan_file_path
def _fast_safetensors_reader(path: str):
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor, checkpoint) -> int:
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (
"time_emb_proj.lora_down" in key
): # recognizes format at https://civitai.com/models/224641
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length