Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Tidy names and locations of modules

- Rename the old "model_management" directory to "model_management_OLD" in order to catch dangling references to the original model manager.
- Caught and fixed most dangling references (still checking).
- Rename the lora, textual_inversion and model_patcher modules.
- Introduce a RawModel base class to simplify the Union returned by the model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
This commit is contained in:
parent 09e7d35b55
commit ed2d9ae0d9
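The headline change is the new RawModel base class, which lets the AnyModel union in model_manager/config.py stop enumerating every wrapper type. A minimal standalone sketch of the before/after shape of that union, using stand-in classes rather than the real InvokeAI imports (the real classes live in the modules touched below):

from typing import Union

import torch
from diffusers import ModelMixin  # base class of diffusers / transformers models


class RawModel:
    """Stand-in for invokeai.backend.raw_model.RawModel (added by this commit)."""


class LoRAModelRaw(RawModel):
    """Stand-in for the renamed invokeai.backend.lora.LoRAModelRaw."""


class TextualInversionModelRaw(RawModel):
    """Stand-in for invokeai.backend.textual_inversion.TextualInversionModelRaw."""


# Before: every wrapper type had to appear in the union.
# AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]

# After: the wrappers share RawModel, so the union collapses to three members.
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]


def loaded_kind(model: AnyModel) -> str:
    # The model patcher can type-check against the shared base class instead of a list of wrappers.
    if isinstance(model, RawModel):
        return "raw wrapper (LoRA, TI, IP-Adapter, ONNX)"
    return "torch / diffusers model"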
@@ -11,9 +11,9 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
 from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
-from invokeai.backend.embeddings.lora import LoRAModelRaw
-from invokeai.backend.embeddings.model_patcher import ModelPatcher
-from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.model_manager import ModelType
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -42,8 +42,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
 from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.embeddings.lora import LoRAModelRaw
-from invokeai.backend.embeddings.model_patcher import ModelPatcher
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -15,7 +15,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.embeddings.model_patcher import ONNXModelPatcher
+from invokeai.backend.model_patcher import ONNXModelPatcher
 from invokeai.backend.model_manager import ModelType, SubModelType
 
 from ...backend.stable_diffusion import PipelineIntermediateState
@@ -20,11 +20,11 @@ class ModelLoadService(ModelLoadServiceBase):
     """Wrapper around AnyModelLoader."""
 
     def __init__(
         self,
         app_config: InvokeAIAppConfig,
         record_store: ModelRecordServiceBase,
-        ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
-        convert_cache: Optional[ModelConvertCacheBase] = None,
+        ram_cache: ModelCacheBase[AnyModel],
+        convert_cache: ModelConvertCacheBase,
     ):
         """Initialize the model load service."""
         logger = InvokeAILogger.get_logger(self.__class__.__name__)
@@ -33,17 +33,8 @@ class ModelLoadService(ModelLoadServiceBase):
         self._any_loader = AnyModelLoader(
             app_config=app_config,
             logger=logger,
-            ram_cache=ram_cache
-            or ModelCache(
-                max_cache_size=app_config.ram_cache_size,
-                max_vram_cache_size=app_config.vram_cache_size,
-                logger=logger,
-            ),
-            convert_cache=convert_cache
-            or ModelConvertCache(
-                cache_path=app_config.models_convert_cache_path,
-                max_size=app_config.convert_cache_size,
-            ),
+            ram_cache=ram_cache,
+            convert_cache=convert_cache,
         )
 
     @property
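With this change ram_cache and convert_cache are required constructor arguments; the fallback construction removed above now belongs to the caller. A hedged sketch of that wiring, with the cache keyword arguments taken from the removed fallback code; the import locations of ModelLoadService, ModelRecordServiceBase, ModelCache and ModelConvertCache are assumptions, not confirmed by this diff:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_load import ModelLoadService  # assumed import path
from invokeai.app.services.model_records import ModelRecordServiceBase  # assumed import path
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache  # assumed import path
from invokeai.backend.util.logging import InvokeAILogger

app_config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger("model_load")
record_store: ModelRecordServiceBase = ...  # obtained from the application's service registry

# The caches the service previously built as defaults are now supplied explicitly.
ram_cache = ModelCache(
    max_cache_size=app_config.ram_cache_size,
    max_vram_cache_size=app_config.vram_cache_size,
    logger=logger,
)
convert_cache = ModelConvertCache(
    cache_path=app_config.models_convert_cache_path,
    max_size=app_config.convert_cache_size,
)

loader = ModelLoadService(
    app_config=app_config,
    record_store=record_store,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
)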
@@ -3,9 +3,10 @@
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel
 
-from .model_manager_default import ModelManagerService
+from .model_manager_default import ModelManagerServiceBase, ModelManagerService
 
 __all__ = [
+    "ModelManagerServiceBase",
     "ModelManagerService",
     "AnyModel",
     "AnyModelConfig",
@@ -3,7 +3,7 @@ from PIL import Image
 
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
 
-from ...backend.model_management.models import BaseModelType
+from ...backend.model_manager import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
 from ..invocations.baseinvocation import InvocationContext
@@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
 
 from .resampler import Resampler
+from ..raw_model import RawModel
 
 
 class ImageProjModel(torch.nn.Module):
@@ -91,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
         return clip_extra_context_tokens
 
 
-class IPAdapter:
+class IPAdapter(RawModel):
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
 
     def __init__(
@@ -10,8 +10,7 @@ from safetensors.torch import load_file
 from typing_extensions import Self
 
 from invokeai.backend.model_manager import BaseModelType
-from .embedding_base import EmbeddingModelRaw
+from .raw_model import RawModel
 
 
 class LoRALayerBase:
@@ -367,9 +366,7 @@ class IA3Layer(LoRALayerBase):
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
 
-# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
-class LoRAModelRaw(EmbeddingModelRaw):  # (torch.nn.Module):
+class LoRAModelRaw(RawModel):  # (torch.nn.Module):
     _name: str
     layers: Dict[str, AnyLoRALayer]
@@ -28,12 +28,11 @@ from diffusers import ModelMixin
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
 
-from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from ..raw_model import RawModel
 
-from ..embeddings.embedding_base import EmbeddingModelRaw
-from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-
-AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
+# ModelMixin is the base class for all diffusers and transformers models
+# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
 
 
 class InvalidModelConfigException(Exception):
@@ -5,7 +5,7 @@ import psutil
 import torch
 from typing_extensions import Self
 
-from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
+from ..util.libc_util import LibcUtil, Struct_mallinfo2
 
 GB = 2**30  # 1 GB
 
@@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
     if snapshot_1.vram is not None and snapshot_2.vram is not None:
         msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
 
-    return "\n" + msg if len(msg) > 0 else msg
+    return msg
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import Optional, Tuple
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.embeddings.lora import LoRAModelRaw
+from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -5,7 +5,7 @@
 from pathlib import Path
 from typing import Optional, Tuple
 
-from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -8,9 +8,7 @@ import torch
 from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
-from invokeai.backend.model_management.models.base import read_checkpoint_meta
-from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
-from invokeai.backend.model_management.util import lora_token_vector_length
+from .util.model_util import lora_token_vector_length, read_checkpoint_meta
 from invokeai.backend.util.util import SilenceWarnings
 
 from .config import (
@@ -55,7 +53,6 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
     },
 }
 
-
 class ProbeBase(object):
     """Base class for probes."""
 
@@ -653,8 +650,8 @@ class LoRAFolderProbe(FolderProbeBase):
 
 
 class IPAdapterFolderProbe(FolderProbeBase):
-    def get_format(self) -> IPAdapterModelFormat:
-        return IPAdapterModelFormat.InvokeAI.value
+    def get_format(self) -> ModelFormat:
+        return ModelFormat.InvokeAI
 
     def get_base_type(self) -> BaseModelType:
         model_file = self.model_path / "ip_adapter.bin"
invokeai/backend/model_manager/util/libc_util.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import ctypes


class Struct_mallinfo2(ctypes.Structure):
    """A ctypes Structure that matches the libc mallinfo2 struct.

    Docs:
    - https://man7.org/linux/man-pages/man3/mallinfo.3.html
    - https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html

    struct mallinfo2 {
        size_t arena;     /* Non-mmapped space allocated (bytes) */
        size_t ordblks;   /* Number of free chunks */
        size_t smblks;    /* Number of free fastbin blocks */
        size_t hblks;     /* Number of mmapped regions */
        size_t hblkhd;    /* Space allocated in mmapped regions (bytes) */
        size_t usmblks;   /* See below */
        size_t fsmblks;   /* Space in freed fastbin blocks (bytes) */
        size_t uordblks;  /* Total allocated space (bytes) */
        size_t fordblks;  /* Total free space (bytes) */
        size_t keepcost;  /* Top-most, releasable space (bytes) */
    };
    """

    _fields_ = [
        ("arena", ctypes.c_size_t),
        ("ordblks", ctypes.c_size_t),
        ("smblks", ctypes.c_size_t),
        ("hblks", ctypes.c_size_t),
        ("hblkhd", ctypes.c_size_t),
        ("usmblks", ctypes.c_size_t),
        ("fsmblks", ctypes.c_size_t),
        ("uordblks", ctypes.c_size_t),
        ("fordblks", ctypes.c_size_t),
        ("keepcost", ctypes.c_size_t),
    ]

    def __str__(self):
        s = ""
        s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
        s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
        s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
        s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
        s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
        s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
        s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
        s += (
            f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
            " (GB)\n"
        )
        s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
        s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
        return s


class LibcUtil:
    """A utility class for interacting with the C Standard Library (`libc`) via ctypes.

    Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
    this shared library is not available.

    TODO: Improve cross-OS compatibility of this class.
    """

    def __init__(self):
        self._libc = ctypes.cdll.LoadLibrary("libc.so.6")

    def mallinfo2(self) -> Struct_mallinfo2:
        """Calls `libc` `mallinfo2`.

        Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
        """
        mallinfo2 = self._libc.mallinfo2
        mallinfo2.restype = Struct_mallinfo2
        return mallinfo2()
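As the LibcUtil docstring notes, __init__() raises when libc.so.6 cannot be loaded (macOS, Windows, musl-based distros), so callers should treat it as optional. A small usage sketch:

from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2


def log_heap_stats() -> None:
    """Print glibc heap statistics, degrading gracefully off glibc platforms."""
    try:
        libc = LibcUtil()  # raises OSError if libc.so.6 cannot be loaded
    except OSError:
        print("mallinfo2 not available on this platform")
        return

    info: Struct_mallinfo2 = libc.mallinfo2()
    # __str__ formats every field; individual fields are also accessible directly.
    print(info)
    print(f"in-use heap: {info.uordblks / 2**30:.3f} GB, free heap: {info.fordblks / 2**30:.3f} GB")


log_heap_stats()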
invokeai/backend/model_manager/util/model_util.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""Utilities for parsing model files, used mostly by probe.py"""

import json
from pathlib import Path
from typing import Union

import safetensors.torch  # used by the safetensors fallback in read_checkpoint_meta()
import torch
from picklescan.scanner import scan_file_path


def _fast_safetensors_reader(path: str):
    checkpoint = {}
    device = torch.device("meta")
    with open(path, "rb") as f:
        definition_len = int.from_bytes(f.read(8), "little")
        definition_json = f.read(definition_len)
        definition = json.loads(definition_json)

        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
            "pt",
            "torch",
            "pytorch",
        }:
            raise Exception("Supported only pytorch safetensors files")
        definition.pop("__metadata__", None)

        for key, info in definition.items():
            dtype = {
                "I8": torch.int8,
                "I16": torch.int16,
                "I32": torch.int32,
                "I64": torch.int64,
                "F16": torch.float16,
                "F32": torch.float32,
                "F64": torch.float64,
            }[info["dtype"]]

            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)

    return checkpoint


def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
    if str(path).endswith(".safetensors"):
        try:
            checkpoint = _fast_safetensors_reader(path)
        except Exception:
            # TODO: create issue for support "meta"?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    else:
        if scan:
            scan_result = scan_file_path(path)
            if scan_result.infected_files != 0:
                raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
        checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key: str, tensor, checkpoint) -> int:
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + "lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + "lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_unet_") and (
            "time_emb_proj.lora_down" in key
        ):  # recognizes format at https://civitai.com/models/224641
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
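A short sketch of how the two helpers combine when inspecting a LoRA checkpoint; the path below is a placeholder, and probe.py imports both helpers as shown in the hunk further up:

from pathlib import Path

from invokeai.backend.model_manager.util.model_util import (
    lora_token_vector_length,
    read_checkpoint_meta,
)

lora_path = Path("/path/to/some_lora.safetensors")  # placeholder path

# Tensors are materialized on the "meta" device, so this stays cheap even for large files.
checkpoint = read_checkpoint_meta(lora_path, scan=True)

length = lora_token_vector_length(checkpoint)
print(f"token vector length: {length}")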
@@ -8,6 +8,7 @@ import numpy as np
 import onnx
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+from ..raw_model import RawModel
 
 ONNX_WEIGHTS_NAME = "model.onnx"
 
invokeai/backend/raw_model.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Base class for 'Raw' models.

The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.

The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""


class RawModel:
    """Base class for 'Raw' model wrappers."""
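Any new wrapper only needs to subclass RawModel to participate in the AnyModel union and the patcher's type checks. A tiny illustrative sketch; the subclass below is hypothetical, while the real subclasses introduced in this commit are LoRAModelRaw, TextualInversionModelRaw and IPAdapter:

import torch

from invokeai.backend.raw_model import RawModel  # the new module added above


class CustomEmbeddingRaw(RawModel):
    """Hypothetical wrapper, used only for illustration."""

    def __init__(self, embedding: torch.Tensor) -> None:
        self.embedding = embedding


wrapper = CustomEmbeddingRaw(torch.zeros(1, 768))
assert isinstance(wrapper, RawModel)  # the shared base class is all the loaders and patcher rely on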
@@ -8,11 +8,9 @@ from compel.embeddings_provider import BaseTextualInversionManager
 from safetensors.torch import load_file
 from transformers import CLIPTokenizer
 from typing_extensions import Self
 
-from .embedding_base import EmbeddingModelRaw
-
-
-class TextualInversionModelRaw(EmbeddingModelRaw):
+from .raw_model import RawModel
+
+
+class TextualInversionModelRaw(RawModel):
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
     embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models
@@ -5,10 +5,9 @@ from typing import Optional, Union
 import pytest
 import torch
 
-from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.backend.install.model_install_backend import ModelInstall
-from invokeai.backend.model_management.model_manager import ModelInfo
-from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
+from invokeai.app.services.model_records import UnknownModelException
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, LoadedModel
 
 
 @pytest.fixture(scope="session")
@@ -16,31 +15,20 @@ def torch_device():
     return "cuda" if torch.cuda.is_available() else "cpu"
 
 
-@pytest.fixture(scope="module")
-def model_installer():
-    """A global ModelInstall pytest fixture to be used by many tests."""
-    # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
-    # between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
-    # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
-    # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
-    # a singleton.
-    return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
-
-
 def install_and_load_model(
-    model_installer: ModelInstall,
+    model_manager: ModelManagerServiceBase,
     model_path_id_or_url: Union[str, Path],
     model_name: str,
     base_model: BaseModelType,
     model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
-) -> ModelInfo:
-    """Install a model if it is not already installed, then get the ModelInfo for that model.
+) -> LoadedModel:
+    """Install a model if it is not already installed, then get the LoadedModel for that model.
 
     This is intended as a utility function for tests.
 
     Args:
-        model_installer (ModelInstall): The model installer.
+        mm2_model_manager (ModelManagerServiceBase): The model manager
         model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
             is not already installed.
         model_name (str): The model name, forwarded to ModelManager.get_model(...).
@@ -51,16 +39,21 @@ def install_and_load_model(
     Returns:
         ModelInfo
     """
-    # If the requested model is already installed, return its ModelInfo.
-    with contextlib.suppress(ModelNotFoundException):
-        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
+    # If the requested model is already installed, return its LoadedModel
+    with contextlib.suppress(UnknownModelException):
+        # TODO: Replace with wrapper call
+        loaded_model: LoadedModel = model_manager.load.load_model_by_attr(name=model_name, base=base_model, type=model_type)
+        return loaded_model
 
     # Install the requested model.
-    model_installer.heuristic_import(model_path_id_or_url)
+    job = model_manager.install.heuristic_import(model_path_id_or_url)
+    model_manager.install.wait_for_job(job, timeout=10)
+    assert job.is_complete
 
     try:
-        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
-    except ModelNotFoundException as e:
+        loaded_model = model_manager.load.load_by_config(job.config)
+        return loaded_model
+    except UnknownModelException as e:
         raise Exception(
             "Failed to get model info after installing it. There could be a mismatch between the requested model and"
             f" the installation id ('{model_path_id_or_url}'). Error: {e}"
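A hedged sketch of how a test might call the reworked helper, assuming a fixture that yields a configured ModelManagerServiceBase; the fixture name and the helper's module path are assumptions, and the model identifiers are taken from the starter-models list removed below:

from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.util.test_utils import install_and_load_model  # assumed module path for this helper


def test_install_and_load_sd15(mm2_model_manager) -> None:
    # mm2_model_manager: a ModelManagerServiceBase provided by a fixture (assumed name).
    loaded_model = install_and_load_model(
        model_manager=mm2_model_manager,
        model_path_id_or_url="runwayml/stable-diffusion-v1-5",
        model_name="stable-diffusion-v1-5",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
    )
    assert loaded_model is not None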
@@ -1,153 +0,0 @@
Deleted file (153 lines removed). Former contents:
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-v1-5:
  description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
  repo_id: runwayml/stable-diffusion-v1-5
  recommended: True
  default: True
sd-1/main/stable-diffusion-v1-5-inpainting:
  description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
  repo_id: runwayml/stable-diffusion-inpainting
  recommended: True
sd-2/main/stable-diffusion-2-1:
  description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-1
  recommended: False
sd-2/main/stable-diffusion-2-inpainting:
  description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-inpainting
  recommended: False
sdxl/main/stable-diffusion-xl-base-1-0:
  description: Stable Diffusion XL base model (12 GB)
  repo_id: stabilityai/stable-diffusion-xl-base-1.0
  recommended: True
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
  description: Stable Diffusion XL refiner model (12 GB)
  repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
  recommended: False
sdxl/vae/sdxl-1-0-vae-fix:
  description: Fine tuned version of the SDXL-1.0 VAE
  repo_id: madebyollin/sdxl-vae-fp16-fix
  recommended: True
sd-1/main/Analog-Diffusion:
  description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
  repo_id: wavymulder/Analog-Diffusion
  recommended: False
sd-1/main/Deliberate_v5:
  description: Versatile model that produces detailed images up to 768px (4.27 GB)
  path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors
  recommended: False
sd-1/main/Dungeons-and-Diffusion:
  description: Dungeons & Dragons characters (2.13 GB)
  repo_id: 0xJustin/Dungeons-and-Diffusion
  recommended: False
sd-1/main/dreamlike-photoreal-2:
  description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
  repo_id: dreamlike-art/dreamlike-photoreal-2.0
  recommended: False
sd-1/main/Inkpunk-Diffusion:
  description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
  repo_id: Envvi/Inkpunk-Diffusion
  recommended: False
sd-1/main/openjourney:
  description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
  repo_id: prompthero/openjourney
  recommended: False
sd-1/main/seek.art_MEGA:
  repo_id: coreco/seek.art_MEGA
  description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
  recommended: False
sd-1/main/trinart_stable_diffusion_v2:
  description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
  repo_id: naclbit/trinart_stable_diffusion_v2
  recommended: False
sd-1/controlnet/qrcode_monster:
  repo_id: monster-labs/control_v1p_sd15_qrcode_monster
  subfolder: v2
sd-1/controlnet/canny:
  repo_id: lllyasviel/control_v11p_sd15_canny
  recommended: True
sd-1/controlnet/inpaint:
  repo_id: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
  repo_id: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
  repo_id: lllyasviel/control_v11f1p_sd15_depth
  recommended: True
sd-1/controlnet/normal_bae:
  repo_id: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
  repo_id: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
  repo_id: lllyasviel/control_v11p_sd15_lineart
  recommended: True
sd-1/controlnet/lineart_anime:
  repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
  repo_id: lllyasviel/control_v11p_sd15_openpose
  recommended: True
sd-1/controlnet/scribble:
  repo_id: lllyasviel/control_v11p_sd15_scribble
  recommended: False
sd-1/controlnet/softedge:
  repo_id: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
  repo_id: lllyasviel/control_v11e_sd15_shuffle
sd-1/controlnet/tile:
  repo_id: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
  repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/t2i_adapter/canny-sd15:
  repo_id: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:
  repo_id: TencentARC/t2iadapter_sketch_sd15v2
sd-1/t2i_adapter/depth-sd15:
  repo_id: TencentARC/t2iadapter_depth_sd15v2
sd-1/t2i_adapter/zoedepth-sd15:
  repo_id: TencentARC/t2iadapter_zoedepth_sd15v1
sdxl/t2i_adapter/canny-sdxl:
  repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0
sdxl/t2i_adapter/zoedepth-sdxl:
  repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
sdxl/t2i_adapter/lineart-sdxl:
  repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0
sdxl/t2i_adapter/sketch-sdxl:
  repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0
sd-1/embedding/EasyNegative:
  path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
  recommended: True
sd-1/embedding/ahx-beta-453407d:
  repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/Ink scenery:
  path: https://civitai.com/api/download/models/83390
sd-1/ip_adapter/ip_adapter_sd15:
  repo_id: InvokeAI/ip_adapter_sd15
  recommended: True
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_sd15:
  repo_id: InvokeAI/ip_adapter_plus_sd15
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: Refined IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
  repo_id: InvokeAI/ip_adapter_plus_face_sd15
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sd_image_encoder
  description: Refined IP-Adapter for SD 1.5 models, adapted for faces
sdxl/ip_adapter/ip_adapter_sdxl:
  repo_id: InvokeAI/ip_adapter_sdxl
  recommended: False
  requires:
    - InvokeAI/ip_adapter_sdxl_image_encoder
  description: IP-Adapter for SDXL models
any/clip_vision/ip_adapter_sd_image_encoder:
  repo_id: InvokeAI/ip_adapter_sd_image_encoder
  recommended: False
  description: Required model for using IP-Adapters with SD-1/2 models
any/clip_vision/ip_adapter_sdxl_image_encoder:
  repo_id: InvokeAI/ip_adapter_sdxl_image_encoder
  recommended: False
  description: Required model for using IP-Adapters with SDXL models
@@ -1,47 +0,0 @@
Deleted file (47 lines removed). Former contents:
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
diffusers-1.4:
  description: 🤗🧨 Stable Diffusion v1.4
  format: diffusers
  repo_id: CompVis/stable-diffusion-v1-4
diffusers-1.5:
  description: 🤗🧨 Stable Diffusion v1.5
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
  default: true
diffusers-1.5+mse:
  description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
  vae:
    repo_id: stabilityai/sd-vae-ft-mse
diffusers-inpainting-1.5:
  description: 🤗🧨 inpainting for Stable Diffusion v1.5
  format: diffusers
  repo_id: runwayml/stable-diffusion-inpainting
stable-diffusion-1.5:
  description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
  config: configs/stable-diffusion/v1-inference.yaml
  width: 512
  height: 512
  vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
stable-diffusion-1.4:
  description: Stable Diffusion inference model version 1.4
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  width: 512
  height: 512
inpainting-1.5:
  weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
  config: configs/stable-diffusion/v1-inpainting-inference.yaml
  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  description: RunwayML SD 1.5 model optimized for inpainting
  width: 512
  height: 512
@ -1,845 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
|
||||||
# Before running stable-diffusion on an internet-isolated machine,
|
|
||||||
# run this script from one with internet connectivity. The
|
|
||||||
# two machines must share a common .cache directory.
|
|
||||||
|
|
||||||
"""
|
|
||||||
This is the npyscreen frontend to the model installation application.
|
|
||||||
The work is actually done in backend code in model_install_backend.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import curses
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import textwrap
|
|
||||||
import traceback
|
|
||||||
from argparse import Namespace
|
|
||||||
from multiprocessing import Process
|
|
||||||
from multiprocessing.connection import Connection, Pipe
|
|
||||||
from pathlib import Path
|
|
||||||
from shutil import get_terminal_size
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import npyscreen
|
|
||||||
import torch
|
|
||||||
from npyscreen import widget
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
|
|
||||||
from invokeai.backend.model_management import ModelManager, ModelType
|
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.frontend.install.widgets import (
|
|
||||||
MIN_COLS,
|
|
||||||
MIN_LINES,
|
|
||||||
BufferBox,
|
|
||||||
CenteredTitleText,
|
|
||||||
CyclingForm,
|
|
||||||
MultiSelectColumns,
|
|
||||||
SingleSelectColumns,
|
|
||||||
TextBox,
|
|
||||||
WindowTooSmallException,
|
|
||||||
select_stable_diffusion_config_file,
|
|
||||||
set_min_terminal_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
|
|
||||||
# build a table mapping all non-printable characters to None
|
|
||||||
# for stripping control characters
|
|
||||||
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
|
|
||||||
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
|
|
||||||
|
|
||||||
# maximum number of installed models we can display before overflowing vertically
|
|
||||||
MAX_OTHER_MODELS = 72
|
|
||||||
|
|
||||||
|
|
||||||
def make_printable(s: str) -> str:
|
|
||||||
"""Replace non-printable characters in a string"""
|
|
||||||
return s.translate(NOPRINT_TRANS_TABLE)
|
|
||||||
|
|
||||||
|
|
||||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|
||||||
# for responsive resizing set to False, but this seems to cause a crash!
|
|
||||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
|
||||||
|
|
||||||
# for persistence
|
|
||||||
current_tab = 0
|
|
||||||
|
|
||||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
|
||||||
self.multipage = multipage
|
|
||||||
self.subprocess = None
|
|
||||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
|
||||||
|
|
||||||
def create(self):
|
|
||||||
self.keypress_timeout = 10
|
|
||||||
self.counter = 0
|
|
||||||
self.subprocess_connection = None
|
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
|
||||||
with open(config.model_conf_path, "w") as file:
|
|
||||||
print("# InvokeAI model configuration file", file=file)
|
|
||||||
self.installer = ModelInstall(config)
|
|
||||||
self.all_models = self.installer.all_models()
|
|
||||||
self.starter_models = self.installer.starter_models()
|
|
||||||
self.model_labels = self._get_model_labels()
|
|
||||||
window_width, window_height = get_terminal_size()
|
|
||||||
|
|
||||||
self.nextrely -= 1
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
|
||||||
editable=False,
|
|
||||||
color="CAUTION",
|
|
||||||
)
|
|
||||||
self.nextrely += 1
|
|
||||||
self.tabs = self.add_widget_intelligent(
|
|
||||||
SingleSelectColumns,
|
|
||||||
values=[
|
|
||||||
"STARTERS",
|
|
||||||
"MAINS",
|
|
||||||
"CONTROLNETS",
|
|
||||||
"T2I-ADAPTERS",
|
|
||||||
"IP-ADAPTERS",
|
|
||||||
"LORAS",
|
|
||||||
"TI EMBEDDINGS",
|
|
||||||
],
|
|
||||||
value=[self.current_tab],
|
|
||||||
columns=7,
|
|
||||||
max_height=2,
|
|
||||||
relx=8,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.tabs.on_changed = self._toggle_tables
|
|
||||||
|
|
||||||
top_of_table = self.nextrely
|
|
||||||
self.starter_pipelines = self.add_starter_pipelines()
|
|
||||||
bottom_of_table = self.nextrely
|
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.pipeline_models = self.add_pipeline_widgets(
|
|
||||||
model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
|
|
||||||
)
|
|
||||||
# self.pipeline_models['autoload_pending'] = True
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.controlnet_models = self.add_model_widgets(
|
|
||||||
model_type=ModelType.ControlNet,
|
|
||||||
window_width=window_width,
|
|
||||||
)
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.t2i_models = self.add_model_widgets(
|
|
||||||
model_type=ModelType.T2IAdapter,
|
|
||||||
window_width=window_width,
|
|
||||||
)
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.ipadapter_models = self.add_model_widgets(
|
|
||||||
model_type=ModelType.IPAdapter,
|
|
||||||
window_width=window_width,
|
|
||||||
)
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.lora_models = self.add_model_widgets(
|
|
||||||
model_type=ModelType.Lora,
|
|
||||||
window_width=window_width,
|
|
||||||
)
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
|
||||||
self.ti_models = self.add_model_widgets(
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
window_width=window_width,
|
|
||||||
)
|
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
|
||||||
|
|
||||||
self.nextrely = bottom_of_table + 1
|
|
||||||
|
|
||||||
self.monitor = self.add_widget_intelligent(
|
|
||||||
BufferBox,
|
|
||||||
name="Log Messages",
|
|
||||||
editable=False,
|
|
||||||
max_height=6,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.nextrely += 1
|
|
||||||
done_label = "APPLY CHANGES"
|
|
||||||
back_label = "BACK"
|
|
||||||
cancel_label = "CANCEL"
|
|
||||||
current_position = self.nextrely
|
|
||||||
if self.multipage:
|
|
||||||
self.back_button = self.add_widget_intelligent(
|
|
||||||
npyscreen.ButtonPress,
|
|
||||||
name=back_label,
|
|
||||||
when_pressed_function=self.on_back,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.nextrely = current_position
|
|
||||||
self.cancel_button = self.add_widget_intelligent(
|
|
||||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
|
||||||
)
|
|
||||||
self.nextrely = current_position
|
|
||||||
self.ok_button = self.add_widget_intelligent(
|
|
||||||
npyscreen.ButtonPress,
|
|
||||||
name=done_label,
|
|
||||||
relx=(window_width - len(done_label)) // 2,
|
|
||||||
when_pressed_function=self.on_execute,
|
|
||||||
)
|
|
||||||
|
|
||||||
label = "APPLY CHANGES & EXIT"
|
|
||||||
self.nextrely = current_position
|
|
||||||
self.done = self.add_widget_intelligent(
|
|
||||||
npyscreen.ButtonPress,
|
|
||||||
name=label,
|
|
||||||
relx=window_width - len(label) - 15,
|
|
||||||
when_pressed_function=self.on_done,
|
|
||||||
)
|
|
||||||
|
|
||||||
# This restores the selected page on return from an installation
|
|
||||||
for _i in range(1, self.current_tab + 1):
|
|
||||||
self.tabs.h_cursor_line_down(1)
|
|
||||||
self._toggle_tables([self.current_tab])
|
|
||||||
|
|
||||||
############# diffusers tab ##########
|
|
||||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
|
||||||
"""Add widgets responsible for selecting diffusers models"""
|
|
||||||
widgets = {}
|
|
||||||
models = self.all_models
|
|
||||||
starters = self.starter_models
|
|
||||||
starter_model_labels = self.model_labels
|
|
||||||
|
|
||||||
self.installed_models = sorted([x for x in starters if models[x].installed])
|
|
||||||
|
|
||||||
widgets.update(
|
|
||||||
label1=self.add_widget_intelligent(
|
|
||||||
CenteredTitleText,
|
|
||||||
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
|
|
||||||
editable=False,
|
|
||||||
labelColor="CAUTION",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.nextrely -= 1
|
|
||||||
# if user has already installed some initial models, then don't patronize them
|
|
||||||
# by showing more recommendations
|
|
||||||
show_recommended = len(self.installed_models) == 0
|
|
||||||
keys = [x for x in models.keys() if x in starters]
|
|
||||||
widgets.update(
|
|
||||||
models_selected=self.add_widget_intelligent(
|
|
||||||
MultiSelectColumns,
|
|
||||||
columns=1,
|
|
||||||
name="Install Starter Models",
|
|
||||||
values=[starter_model_labels[x] for x in keys],
|
|
||||||
value=[
|
|
||||||
keys.index(x)
|
|
||||||
for x in keys
|
|
||||||
if (show_recommended and models[x].recommended) or (x in self.installed_models)
|
|
||||||
],
|
|
||||||
max_height=len(starters) + 1,
|
|
||||||
relx=4,
|
|
||||||
scroll_exit=True,
|
|
||||||
),
|
|
||||||
models=keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.nextrely += 1
|
|
||||||
return widgets
|
|
||||||
|
|
||||||
############# Add a set of model install widgets ########
|
|
||||||
def add_model_widgets(
|
|
||||||
self,
|
|
||||||
model_type: ModelType,
|
|
||||||
window_width: int = 120,
|
|
||||||
install_prompt: str = None,
|
|
||||||
exclude: set = None,
|
|
||||||
) -> dict[str, npyscreen.widget]:
|
|
||||||
"""Generic code to create model selection widgets"""
|
|
||||||
if exclude is None:
|
|
||||||
exclude = set()
|
|
||||||
widgets = {}
|
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
|
||||||
|
|
||||||
show_recommended = len(self.installed_models) == 0
|
|
||||||
truncated = False
|
|
||||||
if len(model_list) > 0:
|
|
||||||
max_width = max([len(x) for x in model_labels])
|
|
||||||
columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
|
|
||||||
columns = min(len(model_list), columns) or 1
|
|
||||||
prompt = (
|
|
||||||
install_prompt
|
|
||||||
or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
|
|
||||||
)
|
|
||||||
|
|
||||||
widgets.update(
|
|
||||||
label1=self.add_widget_intelligent(
|
|
||||||
CenteredTitleText,
|
|
||||||
name=prompt,
|
|
||||||
editable=False,
|
|
||||||
labelColor="CAUTION",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(model_labels) > MAX_OTHER_MODELS:
|
|
||||||
model_labels = model_labels[0:MAX_OTHER_MODELS]
|
|
||||||
truncated = True
|
|
||||||
|
|
||||||
widgets.update(
|
|
||||||
models_selected=self.add_widget_intelligent(
|
|
||||||
MultiSelectColumns,
|
|
||||||
columns=columns,
|
|
||||||
name=f"Install {model_type} Models",
|
|
||||||
values=model_labels,
|
|
||||||
value=[
|
|
||||||
model_list.index(x)
|
|
||||||
for x in model_list
|
|
||||||
if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
|
|
||||||
],
|
|
||||||
max_height=len(model_list) // columns + 1,
|
|
||||||
relx=4,
|
|
||||||
scroll_exit=True,
|
|
||||||
),
|
|
||||||
models=model_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
if truncated:
|
|
||||||
widgets.update(
|
|
||||||
warning_message=self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
|
|
||||||
editable=False,
|
|
||||||
color="CAUTION",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.nextrely += 1
|
|
||||||
widgets.update(
|
|
||||||
download_ids=self.add_widget_intelligent(
|
|
||||||
TextBox,
|
|
||||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
|
||||||
max_height=4,
|
|
||||||
scroll_exit=True,
|
|
||||||
editable=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return widgets
|
|
||||||
|
|
||||||
### Tab for arbitrary diffusers widgets ###
|
|
||||||
def add_pipeline_widgets(
|
|
||||||
self,
|
|
||||||
model_type: ModelType = ModelType.Main,
|
|
||||||
window_width: int = 120,
|
|
||||||
**kwargs,
|
|
||||||
) -> dict[str, npyscreen.widget]:
|
|
||||||
"""Similar to add_model_widgets() but adds some additional widgets at the bottom
|
|
||||||
to support the autoload directory"""
|
|
||||||
widgets = self.add_model_widgets(
|
|
||||||
model_type=model_type,
|
|
||||||
window_width=window_width,
|
|
||||||
install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return widgets
|
|
||||||
|
|
||||||
def resize(self):
|
|
||||||
super().resize()
|
|
||||||
if s := self.starter_pipelines.get("models_selected"):
|
|
||||||
keys = [x for x in self.all_models.keys() if x in self.starter_models]
|
|
||||||
s.values = [self.model_labels[x] for x in keys]
|
|
||||||
|
|
||||||
def _toggle_tables(self, value=None):
|
|
||||||
selected_tab = value[0]
|
|
||||||
widgets = [
|
|
||||||
self.starter_pipelines,
|
|
||||||
self.pipeline_models,
|
|
||||||
self.controlnet_models,
|
|
||||||
self.t2i_models,
|
|
||||||
self.ipadapter_models,
|
|
||||||
self.lora_models,
|
|
||||||
self.ti_models,
|
|
||||||
]
|
|
||||||
|
|
||||||
for group in widgets:
|
|
||||||
for _k, v in group.items():
|
|
||||||
try:
|
|
||||||
v.hidden = True
|
|
||||||
v.editable = False
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
for _k, v in widgets[selected_tab].items():
|
|
||||||
try:
|
|
||||||
v.hidden = False
|
|
||||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
|
||||||
v.editable = True
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.__class__.current_tab = selected_tab # for persistence
|
|
||||||
self.display()
|
|
||||||
|
|
||||||
def _get_model_labels(self) -> dict[str, str]:
|
|
||||||
window_width, window_height = get_terminal_size()
|
|
||||||
checkbox_width = 4
|
|
||||||
spacing_width = 2
|
|
||||||
|
|
||||||
models = self.all_models
|
|
||||||
label_width = max([len(models[x].name) for x in models])
|
|
||||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
for x in models.keys():
|
|
||||||
description = models[x].description
|
|
||||||
description = (
|
|
||||||
description[0 : description_width - 3] + "..."
|
|
||||||
if description and len(description) > description_width
|
|
||||||
else description
|
|
||||||
if description
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _get_columns(self) -> int:
|
|
||||||
window_width, window_height = get_terminal_size()
|
|
||||||
cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1
|
|
||||||
return min(cols, len(self.installed_models))
|
|
||||||
|
|
||||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
|
||||||
remove_models = selections.remove_models
|
|
||||||
if len(remove_models) > 0:
|
|
||||||
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
|
|
||||||
return npyscreen.notify_ok_cancel(
|
|
||||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def on_execute(self):
|
|
||||||
self.marshall_arguments()
|
|
||||||
app = self.parentApp
|
|
||||||
if not self.confirm_deletions(app.install_selections):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
|
|
||||||
self.ok_button.hidden = True
|
|
||||||
self.display()
|
|
||||||
|
|
||||||
# TO DO: Spawn a worker thread, not a subprocess
|
|
||||||
parent_conn, child_conn = Pipe()
|
|
||||||
p = Process(
|
|
||||||
target=process_and_execute,
|
|
||||||
kwargs={
|
|
||||||
"opt": app.program_opts,
|
|
||||||
"selections": app.install_selections,
|
|
||||||
"conn_out": child_conn,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
p.start()
|
|
||||||
child_conn.close()
|
|
||||||
self.subprocess_connection = parent_conn
|
|
||||||
self.subprocess = p
|
|
||||||
app.install_selections = InstallSelections()
|
|
||||||
|
|
||||||
def on_back(self):
|
|
||||||
self.parentApp.switchFormPrevious()
|
|
||||||
self.editing = False
|
|
||||||
|
|
||||||
def on_cancel(self):
|
|
||||||
self.parentApp.setNextForm(None)
|
|
||||||
self.parentApp.user_cancelled = True
|
|
||||||
self.editing = False
|
|
||||||
|
|
||||||
def on_done(self):
|
|
||||||
self.marshall_arguments()
|
|
||||||
if not self.confirm_deletions(self.parentApp.install_selections):
|
|
||||||
return
|
|
||||||
self.parentApp.setNextForm(None)
|
|
||||||
self.parentApp.user_cancelled = False
|
|
||||||
self.editing = False
|
|
||||||
|
|
||||||
########## This routine monitors the child process that is performing model installation and removal #####
|
|
||||||
def while_waiting(self):
|
|
||||||
"""Called during idle periods. Main task is to update the Log Messages box with messages
|
|
||||||
from the child process that does the actual installation/removal"""
|
|
||||||
c = self.subprocess_connection
|
|
||||||
if not c:
|
|
||||||
return
|
|
||||||
|
|
||||||
monitor_widget = self.monitor.entry_widget
|
|
||||||
while c.poll():
|
|
||||||
try:
|
|
||||||
data = c.recv_bytes().decode("utf-8")
|
|
||||||
data = data.strip("\n")
|
|
||||||
|
|
||||||
# processing child is requesting user input to select the
|
|
||||||
# right configuration file
|
|
||||||
if data.startswith("*need v2 config"):
|
|
||||||
_, model_path, *_ = data.split(":", 2)
|
|
||||||
self._return_v2_config(model_path)
|
|
||||||
|
|
||||||
# processing child is done
|
|
||||||
elif data == "*done*":
|
|
||||||
self._close_subprocess_and_regenerate_form()
|
|
||||||
break
|
|
||||||
|
|
||||||
# update the log message box
|
|
||||||
else:
|
|
||||||
data = make_printable(data)
|
|
||||||
data = data.replace("[A", "")
|
|
||||||
monitor_widget.buffer(
|
|
||||||
textwrap.wrap(
|
|
||||||
data,
|
|
||||||
width=monitor_widget.width,
|
|
||||||
subsequent_indent=" ",
|
|
||||||
),
|
|
||||||
scroll_end=True,
|
|
||||||
)
|
|
||||||
self.display()
|
|
||||||
except (EOFError, OSError):
|
|
||||||
self.subprocess_connection = None
|
|
||||||
|
|
||||||
def _return_v2_config(self, model_path: str):
|
|
||||||
c = self.subprocess_connection
|
|
||||||
model_name = Path(model_path).name
|
|
||||||
message = select_stable_diffusion_config_file(model_name=model_name)
|
|
||||||
c.send_bytes(message.encode("utf-8"))
|
|
||||||
|
|
||||||
def _close_subprocess_and_regenerate_form(self):
|
|
||||||
app = self.parentApp
|
|
||||||
self.subprocess_connection.close()
|
|
||||||
self.subprocess_connection = None
|
|
||||||
self.monitor.entry_widget.buffer(["** Action Complete **"])
|
|
||||||
self.display()
|
|
||||||
|
|
||||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
|
||||||
saved_messages = self.monitor.entry_widget.values
|
|
||||||
|
|
||||||
app.main_form = app.addForm(
|
|
||||||
"MAIN",
|
|
||||||
addModelsForm,
|
|
||||||
name="Install Stable Diffusion Models",
|
|
||||||
multipage=self.multipage,
|
|
||||||
)
|
|
||||||
app.switchForm("MAIN")
|
|
||||||
|
|
||||||
app.main_form.monitor.entry_widget.values = saved_messages
|
|
||||||
app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
|
|
||||||
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
|
|
||||||
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
|
|
||||||
|
|
||||||
def marshall_arguments(self):
|
|
||||||
"""
|
|
||||||
Assemble arguments and store as attributes of the application:
|
|
||||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
|
||||||
True => Install
|
|
||||||
False => Remove
|
|
||||||
.scan_directory: Path to a directory of models to scan and import
|
|
||||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
|
||||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
|
||||||
"""
|
|
||||||
selections = self.parentApp.install_selections
|
|
||||||
all_models = self.all_models
|
|
||||||
|
|
||||||
# Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
|
|
||||||
ui_sections = [
|
|
||||||
self.starter_pipelines,
|
|
||||||
self.pipeline_models,
|
|
||||||
self.controlnet_models,
|
|
||||||
self.t2i_models,
|
|
||||||
self.ipadapter_models,
|
|
||||||
self.lora_models,
|
|
||||||
self.ti_models,
|
|
||||||
]
|
|
||||||
for section in ui_sections:
|
|
||||||
if "models_selected" not in section:
|
|
||||||
continue
|
|
||||||
selected = {section["models"][x] for x in section["models_selected"].value}
|
|
||||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
|
||||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
|
||||||
selections.remove_models.extend(models_to_remove)
|
|
||||||
selections.install_models.extend(
|
|
||||||
all_models[x].path or all_models[x].repo_id
|
|
||||||
for x in models_to_install
|
|
||||||
if all_models[x].path or all_models[x].repo_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# models located in the "download_ids" section
|
|
||||||
for section in ui_sections:
|
|
||||||
if downloads := section.get("download_ids"):
|
|
||||||
selections.install_models.extend(downloads.value.split())
|
|
||||||
|
|
||||||
# NOT NEEDED - DONE IN BACKEND NOW
|
|
||||||
# # special case for the ipadapter_models. If any of the adapters are
|
|
||||||
# # chosen, then we add the corresponding encoder(s) to the install list.
|
|
||||||
# section = self.ipadapter_models
|
|
||||||
# if section.get("models_selected"):
|
|
||||||
# selected_adapters = [
|
|
||||||
# self.all_models[section["models"][x]].name for x in section.get("models_selected").value
|
|
||||||
# ]
|
|
||||||
# encoders = []
|
|
||||||
# if any(["sdxl" in x for x in selected_adapters]):
|
|
||||||
# encoders.append("ip_adapter_sdxl_image_encoder")
|
|
||||||
# if any(["sd15" in x for x in selected_adapters]):
|
|
||||||
# encoders.append("ip_adapter_sd_image_encoder")
|
|
||||||
# for encoder in encoders:
|
|
||||||
# key = f"any/clip_vision/{encoder}"
|
|
||||||
# repo_id = f"InvokeAI/{encoder}"
|
|
||||||
# if key not in self.all_models:
|
|
||||||
# selections.install_models.append(repo_id)
|
|
||||||
|
|
||||||
|
|
||||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
|
||||||
def __init__(self, opt):
|
|
||||||
super().__init__()
|
|
||||||
self.program_opts = opt
|
|
||||||
self.user_cancelled = False
|
|
||||||
# self.autoload_pending = True
|
|
||||||
self.install_selections = InstallSelections()
|
|
||||||
|
|
||||||
def onStart(self):
|
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
|
||||||
self.main_form = self.addForm(
|
|
||||||
"MAIN",
|
|
||||||
addModelsForm,
|
|
||||||
name="Install Stable Diffusion Models",
|
|
||||||
cycle_widgets=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StderrToMessage:
|
|
||||||
def __init__(self, connection: Connection):
|
|
||||||
self.connection = connection
|
|
||||||
|
|
||||||
def write(self, data: str):
|
|
||||||
self.connection.send_bytes(data.encode("utf-8"))
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
|
|
||||||
if tui_conn:
|
|
||||||
logger.debug("Waiting for user response...")
|
|
||||||
return _ask_user_for_pt_tui(model_path, tui_conn)
|
|
||||||
else:
|
|
||||||
return _ask_user_for_pt_cmdline(model_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]:
|
|
||||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
|
||||||
print(
|
|
||||||
f"""
|
|
||||||
Please select the scheduler prediction type of the checkpoint named {model_path.name}:
|
|
||||||
[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
|
|
||||||
[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
|
|
||||||
[3] Accept the best guess; you can fix it in the Web UI later
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
choice = None
|
|
||||||
ok = False
|
|
||||||
while not ok:
|
|
||||||
try:
|
|
||||||
choice = input("select [3]> ").strip()
|
|
||||||
if not choice:
|
|
||||||
return None
|
|
||||||
choice = choices[int(choice) - 1]
|
|
||||||
ok = True
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
print(f"{choice} is not a valid choice")
|
|
||||||
except EOFError:
|
|
||||||
return
|
|
||||||
return choice
|
|
||||||
|
|
||||||
|
|
||||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
|
|
||||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
|
||||||
# note that we don't do any status checking here
|
|
||||||
response = tui_conn.recv_bytes().decode("utf-8")
|
|
||||||
if response is None:
|
|
||||||
return None
|
|
||||||
elif response == "epsilon":
|
|
||||||
return SchedulerPredictionType.Epsilon
|
|
||||||
elif response == "v":
|
|
||||||
return SchedulerPredictionType.VPrediction
|
|
||||||
elif response == "guess":
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
|
||||||
def process_and_execute(
|
|
||||||
opt: Namespace,
|
|
||||||
selections: InstallSelections,
|
|
||||||
conn_out: Connection = None,
|
|
||||||
):
|
|
||||||
# need to reinitialize config in subprocess
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
args = ["--root", opt.root] if opt.root else []
|
|
||||||
config.parse_args(args)
|
|
||||||
|
|
||||||
# set up so that stderr is sent to conn_out
|
|
||||||
if conn_out:
|
|
||||||
translator = StderrToMessage(conn_out)
|
|
||||||
sys.stderr = translator
|
|
||||||
sys.stdout = translator
|
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
logger.handlers.clear()
|
|
||||||
logger.addHandler(logging.StreamHandler(translator))
|
|
||||||
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
|
|
||||||
installer.install(selections)
|
|
||||||
|
|
||||||
if conn_out:
|
|
||||||
conn_out.send_bytes("*done*".encode("utf-8"))
|
|
||||||
conn_out.close()
|
|
||||||
|
|
||||||
|
|
||||||
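The installer form and its worker process communicate over a multiprocessing Pipe with three ad-hoc string messages: ordinary text is buffered into the log widget, a message beginning with "*need v2 config for:<path>" asks the TUI to pick a v2 configuration file and waits for the reply, and "*done*" ends the exchange. A minimal, self-contained sketch of that round trip follows; the message text, file names and the worker() function are illustrative only, not part of InvokeAI:

from multiprocessing import Pipe, Process


def worker(conn):
    conn.send_bytes(b"downloading model weights...")            # ordinary log line
    conn.send_bytes(b"*need v2 config for:/tmp/v2-model.ckpt")  # ask the TUI for input
    answer = conn.recv_bytes().decode("utf-8")                  # block until the TUI replies
    conn.send_bytes(f"using config file {answer}".encode("utf-8"))
    conn.send_bytes(b"*done*")                                  # signal completion
    conn.close()


if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    p = Process(target=worker, args=(child_conn,))
    p.start()
    child_conn.close()  # the parent keeps only its own end, as on_execute() does
    while True:
        msg = parent_conn.recv_bytes().decode("utf-8")
        if msg.startswith("*need v2 config"):
            parent_conn.send_bytes(b"v2-inference-v.yaml")  # stand-in for the selection dialog
        elif msg == "*done*":
            break
        else:
            print(msg)  # the real form wraps these into the log messages box
    p.join()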
# --------------------------------------------------------
|
|
||||||
def select_and_download_models(opt: Namespace):
|
|
||||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
|
||||||
config.precision = precision
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
|
|
||||||
if opt.list_models:
|
|
||||||
installer.list_models(opt.list_models)
|
|
||||||
elif opt.add or opt.delete:
|
|
||||||
selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
|
|
||||||
installer.install(selections)
|
|
||||||
elif opt.default_only:
|
|
||||||
selections = InstallSelections(install_models=installer.default_model())
|
|
||||||
installer.install(selections)
|
|
||||||
elif opt.yes_to_all:
|
|
||||||
selections = InstallSelections(install_models=installer.recommended_models())
|
|
||||||
installer.install(selections)
|
|
||||||
|
|
||||||
# this is where the TUI is called
|
|
||||||
else:
|
|
||||||
# needed to support the probe() method running under a subprocess
|
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
|
||||||
|
|
||||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
|
||||||
raise WindowTooSmallException(
|
|
||||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
|
||||||
)
|
|
||||||
|
|
||||||
installApp = AddModelApplication(opt)
|
|
||||||
try:
|
|
||||||
installApp.run()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
if hasattr(installApp, "main_form"):
|
|
||||||
if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
|
|
||||||
logger.info("Terminating subprocesses")
|
|
||||||
installApp.main_form.subprocess.terminate()
|
|
||||||
installApp.main_form.subprocess = None
|
|
||||||
raise e
|
|
||||||
process_and_execute(opt, installApp.install_selections)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
|
||||||
parser.add_argument(
|
|
||||||
"--add",
|
|
||||||
nargs="*",
|
|
||||||
help="List of URLs, local paths or repo_ids of models to install",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--delete",
|
|
||||||
nargs="*",
|
|
||||||
help="List of names of models to idelete",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--full-precision",
|
|
||||||
dest="full_precision",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
help="use 32-bit weights instead of faster 16-bit weights",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--yes",
|
|
||||||
"-y",
|
|
||||||
dest="yes_to_all",
|
|
||||||
action="store_true",
|
|
||||||
help='answer "yes" to all prompts',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--default_only",
|
|
||||||
action="store_true",
|
|
||||||
help="Only install the default model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--list-models",
|
|
||||||
choices=[x.value for x in ModelType],
|
|
||||||
help="list installed models",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config_file",
|
|
||||||
"-c",
|
|
||||||
dest="config_file",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="path to configuration file to create",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--root_dir",
|
|
||||||
dest="root",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="path to root of install directory",
|
|
||||||
)
|
|
||||||
opt = parser.parse_args()
|
|
||||||
|
|
||||||
invoke_args = []
|
|
||||||
if opt.root:
|
|
||||||
invoke_args.extend(["--root", opt.root])
|
|
||||||
if opt.full_precision:
|
|
||||||
invoke_args.extend(["--precision", "float32"])
|
|
||||||
config.parse_args(invoke_args)
|
|
||||||
logger = InvokeAILogger().get_logger(config=config)
|
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
|
||||||
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
|
||||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
|
||||||
|
|
||||||
invokeai_configure()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
try:
|
|
||||||
select_and_download_models(opt)
|
|
||||||
except AssertionError as e:
|
|
||||||
logger.error(e)
|
|
||||||
sys.exit(-1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
curses.nocbreak()
|
|
||||||
curses.echo()
|
|
||||||
curses.endwin()
|
|
||||||
logger.info("Goodbye! Come back soon.")
|
|
||||||
except WindowTooSmallException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
|
||||||
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
|
|
||||||
input("Press any key to continue...")
|
|
||||||
except Exception as e:
|
|
||||||
if str(e).startswith("addwstr"):
|
|
||||||
logger.error(
|
|
||||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(f"An exception has occurred: {str(e)} Details:")
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
input("Press any key to continue...")
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,438 +0,0 @@
|
|||||||
"""
|
|
||||||
invokeai.frontend.merge exports a single function called merge_diffusion_models().
|
|
||||||
|
|
||||||
It merges 2-3 models together and creates a new InvokeAI-registered diffusion model.
|
|
||||||
|
|
||||||
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import curses
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from argparse import Namespace
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import npyscreen
|
|
||||||
from npyscreen import widget
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
|
||||||
from invokeai.backend.install.install_helper import initialize_installer
|
|
||||||
from invokeai.backend.model_manager import (
|
|
||||||
BaseModelType,
|
|
||||||
ModelFormat,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
)
|
|
||||||
from invokeai.backend.model_manager.merge import ModelMerger
|
|
||||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
BASE_TYPES = [
|
|
||||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
|
||||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
|
||||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> Namespace:
|
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
|
||||||
parser.add_argument(
|
|
||||||
"--root_dir",
|
|
||||||
type=Path,
|
|
||||||
default=config.root,
|
|
||||||
help="Path to the invokeai runtime directory",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--front_end",
|
|
||||||
"--gui",
|
|
||||||
dest="front_end",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--models",
|
|
||||||
dest="model_names",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="Two to three model names to be merged",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--base_model",
|
|
||||||
type=str,
|
|
||||||
choices=[x[0].value for x in BASE_TYPES],
|
|
||||||
help="The base model shared by the models to be merged",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--merged_model_name",
|
|
||||||
"--destination",
|
|
||||||
dest="merged_model_name",
|
|
||||||
type=str,
|
|
||||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--alpha",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--interpolation",
|
|
||||||
dest="interp",
|
|
||||||
type=str,
|
|
||||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
|
||||||
default="weighted_sum",
|
|
||||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--force",
|
|
||||||
action="store_true",
|
|
||||||
help="Try to merge models even if they are incompatible with each other",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--clobber",
|
|
||||||
"--overwrite",
|
|
||||||
dest="clobber",
|
|
||||||
action="store_true",
|
|
||||||
help="Overwrite the merged model if --merged_model_name already exists",
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------- GUI HERE -------------------------
|
|
||||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|
||||||
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
|
|
||||||
|
|
||||||
def __init__(self, parentApp, name):
|
|
||||||
self.parentApp = parentApp
|
|
||||||
self.ALLOW_RESIZE = True
|
|
||||||
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
|
||||||
super().__init__(parentApp, name)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_record_store(self) -> ModelRecordServiceBase:
|
|
||||||
installer: ModelInstallServiceBase = self.parentApp.installer
|
|
||||||
return installer.record_store
|
|
||||||
|
|
||||||
def afterEditing(self) -> None:
|
|
||||||
self.parentApp.setNextForm(None)
|
|
||||||
|
|
||||||
def create(self) -> None:
|
|
||||||
window_height, window_width = curses.initscr().getmaxyx()
|
|
||||||
self.current_base = 0
|
|
||||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
|
||||||
self.model_names = [x[1] for x in self.models]
|
|
||||||
max_width = max([len(x) for x in self.model_names])
|
|
||||||
max_width += 6
|
|
||||||
horizontal_layout = max_width * 3 < window_width
|
|
||||||
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
color="CONTROL",
|
|
||||||
value="Select two models to merge and optionally a third.",
|
|
||||||
editable=False,
|
|
||||||
)
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
color="CONTROL",
|
|
||||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
|
||||||
editable=False,
|
|
||||||
)
|
|
||||||
self.nextrely += 1
|
|
||||||
self.base_select = self.add_widget_intelligent(
|
|
||||||
SingleSelectColumns,
|
|
||||||
values=[x[1] for x in BASE_TYPES],
|
|
||||||
value=[self.current_base],
|
|
||||||
columns=4,
|
|
||||||
max_height=2,
|
|
||||||
relx=8,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.base_select.on_changed = self._populate_models
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
value="MODEL 1",
|
|
||||||
color="GOOD",
|
|
||||||
editable=False,
|
|
||||||
rely=6 if horizontal_layout else None,
|
|
||||||
)
|
|
||||||
self.model1 = self.add_widget_intelligent(
|
|
||||||
npyscreen.SelectOne,
|
|
||||||
values=self.model_names,
|
|
||||||
value=0,
|
|
||||||
max_height=len(self.model_names),
|
|
||||||
max_width=max_width,
|
|
||||||
scroll_exit=True,
|
|
||||||
rely=7,
|
|
||||||
)
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
value="MODEL 2",
|
|
||||||
color="GOOD",
|
|
||||||
editable=False,
|
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
|
||||||
rely=6 if horizontal_layout else None,
|
|
||||||
)
|
|
||||||
self.model2 = self.add_widget_intelligent(
|
|
||||||
npyscreen.SelectOne,
|
|
||||||
name="(2)",
|
|
||||||
values=self.model_names,
|
|
||||||
value=1,
|
|
||||||
max_height=len(self.model_names),
|
|
||||||
max_width=max_width,
|
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
|
||||||
rely=7 if horizontal_layout else None,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.add_widget_intelligent(
|
|
||||||
npyscreen.FixedText,
|
|
||||||
value="MODEL 3",
|
|
||||||
color="GOOD",
|
|
||||||
editable=False,
|
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
|
||||||
rely=6 if horizontal_layout else None,
|
|
||||||
)
|
|
||||||
models_plus_none = self.model_names.copy()
|
|
||||||
models_plus_none.insert(0, "None")
|
|
||||||
self.model3 = self.add_widget_intelligent(
|
|
||||||
npyscreen.SelectOne,
|
|
||||||
name="(3)",
|
|
||||||
values=models_plus_none,
|
|
||||||
value=0,
|
|
||||||
max_height=len(self.model_names) + 1,
|
|
||||||
max_width=max_width,
|
|
||||||
scroll_exit=True,
|
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
|
||||||
rely=7 if horizontal_layout else None,
|
|
||||||
)
|
|
||||||
for m in [self.model1, self.model2, self.model3]:
|
|
||||||
m.when_value_edited = self.models_changed
|
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
|
||||||
TextBox,
|
|
||||||
name="Name for merged model:",
|
|
||||||
labelColor="CONTROL",
|
|
||||||
max_height=3,
|
|
||||||
value="",
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.force = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Force merge of models created by different diffusers library versions",
|
|
||||||
labelColor="CONTROL",
|
|
||||||
value=True,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.nextrely += 1
|
|
||||||
self.merge_method = self.add_widget_intelligent(
|
|
||||||
npyscreen.TitleSelectOne,
|
|
||||||
name="Merge Method:",
|
|
||||||
values=self.interpolations,
|
|
||||||
value=0,
|
|
||||||
labelColor="CONTROL",
|
|
||||||
max_height=len(self.interpolations) + 1,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.alpha = self.add_widget_intelligent(
|
|
||||||
FloatTitleSlider,
|
|
||||||
name="Weight (alpha) to assign to second and third models:",
|
|
||||||
out_of=1.0,
|
|
||||||
step=0.01,
|
|
||||||
lowest=0,
|
|
||||||
value=0.5,
|
|
||||||
labelColor="CONTROL",
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.model1.editing = True
|
|
||||||
|
|
||||||
def models_changed(self) -> None:
|
|
||||||
models = self.model1.values
|
|
||||||
selected_model1 = self.model1.value[0]
|
|
||||||
selected_model2 = self.model2.value[0]
|
|
||||||
selected_model3 = self.model3.value[0]
|
|
||||||
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
|
|
||||||
self.merged_model_name.value = merged_model_name
|
|
||||||
|
|
||||||
if selected_model3 > 0:
|
|
||||||
self.merge_method.values = ["add_difference ( A+(B-C) )"]
|
|
||||||
self.merged_model_name.value += f"+{models[selected_model3 - 1]}"  # In model3 there is one more element in the list (None), so we have to subtract one.
|
|
||||||
else:
|
|
||||||
self.merge_method.values = self.interpolations
|
|
||||||
self.merge_method.value = 0
|
|
||||||
|
|
||||||
def on_ok(self) -> None:
|
|
||||||
if self.validate_field_values() and self.check_for_overwrite():
|
|
||||||
self.parentApp.setNextForm(None)
|
|
||||||
self.editing = False
|
|
||||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
|
||||||
npyscreen.notify("Starting the merge...")
|
|
||||||
else:
|
|
||||||
self.editing = True
|
|
||||||
|
|
||||||
def on_cancel(self) -> None:
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
def marshall_arguments(self) -> dict:
|
|
||||||
model_keys = [x[0] for x in self.models]
|
|
||||||
models = [
|
|
||||||
model_keys[self.model1.value[0]],
|
|
||||||
model_keys[self.model2.value[0]],
|
|
||||||
]
|
|
||||||
if self.model3.value[0] > 0:
|
|
||||||
models.append(model_keys[self.model3.value[0] - 1])
|
|
||||||
interp = "add_difference"
|
|
||||||
else:
|
|
||||||
interp = self.interpolations[self.merge_method.value[0]]
|
|
||||||
|
|
||||||
args = {
|
|
||||||
"model_keys": models,
|
|
||||||
"alpha": self.alpha.value,
|
|
||||||
"interp": interp,
|
|
||||||
"force": self.force.value,
|
|
||||||
"merged_model_name": self.merged_model_name.value,
|
|
||||||
}
|
|
||||||
return args
|
|
||||||
|
|
||||||
def check_for_overwrite(self) -> bool:
|
|
||||||
model_out = self.merged_model_name.value
|
|
||||||
if model_out not in self.model_names:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
result: bool = npyscreen.notify_yes_no(
|
|
||||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def validate_field_values(self) -> bool:
|
|
||||||
bad_fields = []
|
|
||||||
model_names = self.model_names
|
|
||||||
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
|
||||||
if self.model3.value[0] > 0:
|
|
||||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
|
||||||
if len(selected_models) < 2:
|
|
||||||
bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
|
|
||||||
if len(bad_fields) > 0:
|
|
||||||
message = "The following problems were detected and must be corrected:"
|
|
||||||
for problem in bad_fields:
|
|
||||||
message += f"\n* {problem}"
|
|
||||||
npyscreen.notify_confirm(message)
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
|
||||||
models = [
|
|
||||||
(x.key, x.name)
|
|
||||||
for x in self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
|
||||||
if x.format == ModelFormat("diffusers")
|
|
||||||
and hasattr(x, "variant")
|
|
||||||
and x.variant == ModelVariantType("normal")
|
|
||||||
]
|
|
||||||
return sorted(models, key=lambda x: x[1])
|
|
||||||
|
|
||||||
def _populate_models(self, value: List[int]) -> None:
|
|
||||||
base_model = BASE_TYPES[value[0]][0]
|
|
||||||
self.models = self.get_models(base_model)
|
|
||||||
self.model_names = [x[1] for x in self.models]
|
|
||||||
|
|
||||||
models_plus_none = self.model_names.copy()
|
|
||||||
models_plus_none.insert(0, "None")
|
|
||||||
self.model1.values = self.model_names
|
|
||||||
self.model2.values = self.model_names
|
|
||||||
self.model3.values = models_plus_none
|
|
||||||
|
|
||||||
self.display()
|
|
||||||
|
|
||||||
|
|
||||||
# npyscreen is untyped and causes mypy to get naggy
|
|
||||||
class Mergeapp(npyscreen.NPSAppManaged): # type: ignore
|
|
||||||
def __init__(self, installer: ModelInstallServiceBase):
|
|
||||||
"""Initialize the npyscreen application."""
|
|
||||||
super().__init__()
|
|
||||||
self.installer = installer
|
|
||||||
|
|
||||||
def onStart(self) -> None:
|
|
||||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
|
||||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
|
||||||
|
|
||||||
|
|
||||||
def run_gui(args: Namespace) -> None:
|
|
||||||
installer = initialize_installer(config)
|
|
||||||
mergeapp = Mergeapp(installer)
|
|
||||||
mergeapp.run()
|
|
||||||
merge_args = mergeapp.merge_arguments
|
|
||||||
merger = ModelMerger(installer)
|
|
||||||
merger.merge_diffusion_models_and_save(**merge_args)
|
|
||||||
logger.info(f'Models merged into new model: "{merge_args["merged_model_name"]}".')
|
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace) -> None:
|
|
||||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
|
||||||
assert (
|
|
||||||
args.model_names and len(args.model_names) >= 2 and len(args.model_names) <= 3
|
|
||||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
|
||||||
|
|
||||||
if not args.merged_model_name:
|
|
||||||
args.merged_model_name = "+".join(args.model_names)
|
|
||||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
|
||||||
|
|
||||||
installer = initialize_installer(config)
|
|
||||||
store = installer.record_store
|
|
||||||
assert (
|
|
||||||
len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
|
||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
|
||||||
|
|
||||||
merger = ModelMerger(installer)
|
|
||||||
model_keys = []
|
|
||||||
for name in args.model_names:
|
|
||||||
if len(name) == 32 and re.match(r"^[0-9a-f]{32}$", name):
|
|
||||||
model_keys.append(name)
|
|
||||||
else:
|
|
||||||
models = store.search_by_attr(
|
|
||||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
|
||||||
)
|
|
||||||
assert len(models) > 0, f"{name}: Unknown model"
|
|
||||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
|
||||||
model_keys.append(models[0].key)
|
|
||||||
|
|
||||||
merger.merge_diffusion_models_and_save(
|
|
||||||
alpha=args.alpha,
|
|
||||||
model_keys=model_keys,
|
|
||||||
merged_model_name=args.merged_model_name,
|
|
||||||
interp=args.interp,
|
|
||||||
force=args.force,
|
|
||||||
)
|
|
||||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
args = _parse_args()
|
|
||||||
if args.root_dir:
|
|
||||||
config.parse_args(["--root", str(args.root_dir)])
|
|
||||||
else:
|
|
||||||
config.parse_args([])
|
|
||||||
|
|
||||||
try:
|
|
||||||
if args.front_end:
|
|
||||||
run_gui(args)
|
|
||||||
else:
|
|
||||||
run_cli(args)
|
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
|
||||||
logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge")
|
|
||||||
else:
|
|
||||||
logger.error("Not enough room for the user interface. Try making this window larger.")
|
|
||||||
sys.exit(-1)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
sys.exit(-1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
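For reference, the merge that this script drives can also be invoked directly from Python, using the same imports and call signature as run_cli() above. This is a minimal sketch; the record keys and the output name are hypothetical placeholders:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager.merge import ModelMerger

config = InvokeAIAppConfig.get_config()
config.parse_args([])
installer = initialize_installer(config)
merger = ModelMerger(installer)
merger.merge_diffusion_models_and_save(
    model_keys=["<key-of-first-model>", "<key-of-second-model>"],  # hypothetical record keys
    merged_model_name="my-merged-model",
    alpha=0.5,                 # weight given to the second (and third) model
    interp="weighted_sum",
    force=False,
)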
@@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
 )
 from invokeai.app.services.model_records import UnknownModelException
 from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
-from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

 OS = platform.uname().system

@@ -26,7 +26,7 @@ from invokeai.backend.model_manager.config import (
 )
 from invokeai.backend.model_manager.metadata import BaseMetadata
 from invokeai.backend.util.logging import InvokeAILogger
-from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 from tests.fixtures.sqlite_database import create_mock_sqlite_database

@@ -2,7 +2,7 @@ import pytest
 import torch

 from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model

@@ -5,17 +5,16 @@ Test model loading
 from pathlib import Path

 from invokeai.app.services.model_install import ModelInstallServiceBase
-from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
-from invokeai.backend.model_manager.load import AnyModelLoader
-from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
-
+from invokeai.app.services.model_load import ModelLoadServiceBase
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

-def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
+def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
     store = mm2_installer.record_store
     matches = store.search_by_attr(model_name="test_embedding")
     assert len(matches) == 0
     key = mm2_installer.register_path(embedding_file)
-    loaded_model = mm2_loader.load_model(store.get_model(key))
+    loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
     assert loaded_model is not None
     assert loaded_model.config.key == key
     with loaded_model as model:
@ -6,24 +6,27 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pytest import FixtureRequest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from requests_testadapter import TestAdapter, TestSession
|
from requests_testadapter import TestAdapter, TestSession
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadQueueService
|
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
|
||||||
|
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
|
||||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||||
RepoCivitaiModelMetadata1,
|
RepoCivitaiModelMetadata1,
|
||||||
RepoCivitaiVersionMetadata1,
|
RepoCivitaiVersionMetadata1,
|
||||||
RepoHFMetadata1,
|
RepoHFMetadata1,
|
||||||
@ -86,22 +89,71 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
|||||||
app_config = InvokeAIAppConfig(
|
app_config = InvokeAIAppConfig(
|
||||||
root=mm2_root_dir,
|
root=mm2_root_dir,
|
||||||
models_dir=mm2_root_dir / "models",
|
models_dir=mm2_root_dir / "models",
|
||||||
|
log_level="info",
|
||||||
)
|
)
|
||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
|
def mm2_download_queue(mm2_session: Session,
|
||||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
request: FixtureRequest
|
||||||
|
) -> DownloadQueueServiceBase:
|
||||||
|
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||||
|
download_queue.start()
|
||||||
|
|
||||||
|
def stop_queue() -> None:
|
||||||
|
download_queue.stop()
|
||||||
|
|
||||||
|
request.addfinalizer(stop_queue)
|
||||||
|
return download_queue
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||||
|
return mm2_record_store.metadata_store
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||||
ram_cache = ModelCache(
|
ram_cache = ModelCache(
|
||||||
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
|
logger=InvokeAILogger.get_logger(),
|
||||||
|
max_cache_size=mm2_app_config.ram_cache_size,
|
||||||
|
max_vram_cache_size=mm2_app_config.vram_cache_size
|
||||||
)
|
)
|
||||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||||
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
|
return ModelLoadService(app_config=mm2_app_config,
|
||||||
|
record_store=mm2_record_store,
|
||||||
|
ram_cache=ram_cache,
|
||||||
|
convert_cache=convert_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mm2_installer(mm2_app_config: InvokeAIAppConfig,
|
||||||
|
mm2_download_queue: DownloadQueueServiceBase,
|
||||||
|
mm2_session: Session,
|
||||||
|
request: FixtureRequest,
|
||||||
|
) -> ModelInstallServiceBase:
|
||||||
|
logger = InvokeAILogger.get_logger()
|
||||||
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
|
events = DummyEventService()
|
||||||
|
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
|
|
||||||
|
installer = ModelInstallService(
|
||||||
|
app_config=mm2_app_config,
|
||||||
|
record_store=store,
|
||||||
|
download_queue=mm2_download_queue,
|
||||||
|
event_bus=events,
|
||||||
|
session=mm2_session,
|
||||||
|
)
|
||||||
|
installer.start()
|
||||||
|
|
||||||
|
def stop_installer() -> None:
|
||||||
|
installer.stop()
|
||||||
|
|
||||||
|
request.addfinalizer(stop_installer)
|
||||||
|
return installer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
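The download-queue and installer fixtures added above start a service thread and register a finalizer via request.addfinalizer() so that the service is stopped and its threads released when the test finishes. A yield-style fixture is an equivalent way to express the same teardown; this is a minimal sketch under that assumption, not part of the change itself:

import pytest

from invokeai.app.services.download import DownloadQueueService


@pytest.fixture
def download_queue_yield_style(mm2_session):
    queue = DownloadQueueService(requests_session=mm2_session)
    queue.start()
    yield queue   # the test body runs while the queue thread is alive
    queue.stop()  # teardown: stop the service and release its threads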
@ -161,11 +213,15 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
|||||||
store.add_model("test_config_5", raw5)
|
store.add_model("test_config_5", raw5)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
|
||||||
return mm2_record_store.metadata_store
|
mm2_installer: ModelInstallServiceBase,
|
||||||
|
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
|
||||||
|
return ModelManagerService(
|
||||||
|
store=mm2_record_store,
|
||||||
|
install=mm2_installer,
|
||||||
|
load=mm2_loader
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||||
@ -252,22 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
return sess
|
return sess
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
|
||||||
events = DummyEventService()
|
|
||||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
|
||||||
|
|
||||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
|
||||||
download_queue.start()
|
|
||||||
|
|
||||||
installer = ModelInstallService(
|
|
||||||
app_config=mm2_app_config,
|
|
||||||
record_store=store,
|
|
||||||
download_queue=download_queue,
|
|
||||||
event_bus=events,
|
|
||||||
session=mm2_session,
|
|
||||||
)
|
|
||||||
installer.start()
|
|
||||||
return installer
|
|
@@ -19,7 +19,7 @@ from invokeai.backend.model_manager.metadata import (
     UnknownMetadataException,
 )
 from invokeai.backend.model_manager.util import select_hf_files
-from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


 def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
@@ -1,6 +1,6 @@
 import pytest

-from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
+from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2


 def test_libc_util_mallinfo2():
@@ -5,8 +5,8 @@
 import pytest
 import torch

-from invokeai.backend.model_management.lora import ModelPatcher
-from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora import LoRALayer, LoRAModelRaw


 @pytest.mark.parametrize(
@@ -1,8 +1,7 @@
 import pytest

-from invokeai.backend.model_management.libc_util import Struct_mallinfo2
-from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-
+from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff

 def test_memory_snapshot_capture():
     """Smoke test of MemorySnapshot.capture()."""
@@ -26,6 +25,7 @@ snapshots = [
 def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
     """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
     msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
+    print(msg)

     expected_lines = 0
     if snapshot_1 is not None and snapshot_2 is not None:
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
+from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init


 @pytest.mark.parametrize(
@@ -1,7 +1,2 @@
 # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
 # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
-
-
-# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
-# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
-from invokeai.backend.util.test_utils import model_installer, torch_device  # noqa: F401