Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
make model manager v2 ready for PR review
- Replace the legacy model manager service with the v2 manager.
- Update invocations to use the new load interface.
- Fix many (but not all) type-checking errors in the invocations; most were unrelated to the model manager.
- Update routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered.
- Add a pytest for the loader.
- Update documentation in contributing/MODEL_MANAGER.md.
Committed by: Brandon Rising
Parent: 721ff58e44
Commit: d56337f2d8
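The commit message describes the new model-manager routes only at a high level: they carry the route tag `model_manager_v2`, are mounted under `/api/v2/models`, and the old routes are de-registered. A minimal FastAPI sketch of that arrangement follows; the router variable and the `list_models` endpoint are hypothetical placeholders, not the actual InvokeAI route definitions.

from fastapi import APIRouter, FastAPI

# Hypothetical router mirroring the prefix and tag named in the commit message.
model_manager_v2_router = APIRouter(prefix="/api/v2/models", tags=["model_manager_v2"])


@model_manager_v2_router.get("/")
async def list_models() -> list[dict]:
    """Illustrative endpoint only; the real handlers live in InvokeAI's API routers."""
    return []


app = FastAPI()
app.include_router(model_manager_v2_router)
# The legacy v1 model routes are simply not included here, i.e. "de-registered".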
@@ -64,7 +64,7 @@ class ModelPatcher:
     def apply_lora_unet(
         cls,
         unet: UNet2DConditionModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
     ) -> None:
         with cls.apply_lora(unet, loras, "lora_unet_"):
             yield
@@ -307,7 +307,7 @@ class ONNXModelPatcher:
     def apply_lora_unet(
         cls,
         unet: OnnxRuntimeModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
     ) -> None:
         with cls.apply_lora(unet, loras, "lora_unet_"):
             yield
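The two hunks above change `loras` from `List[...]` to `Iterator[...]` in both patchers, so callers may hand the patcher a lazily produced stream of (LoRA, weight) pairs rather than a materialized list. A small stand-alone sketch of that calling convention; `FakeLoRA` and the weights are made up for illustration and are not InvokeAI classes.

from typing import Iterator, Tuple


class FakeLoRA:
    """Stand-in for LoRAModelRaw, used only to show the call shape."""

    def __init__(self, name: str) -> None:
        self.name = name


def iter_loras(names_and_weights: dict) -> Iterator[Tuple[FakeLoRA, float]]:
    # A generator satisfies Iterator[...]; nothing is built until it is iterated.
    for name, weight in names_and_weights.items():
        yield FakeLoRA(name), weight


# An Iterator can be consumed only once, so the patcher should walk it a single
# time (or materialize it first); that is the practical difference from List.
for lora, weight in iter_loras({"style_lora": 0.75, "detail_lora": 0.5}):
    print(lora.name, weight)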
@@ -8,8 +8,8 @@ from PIL import Image

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend import SilenceWarnings
 from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.silence_warnings import SilenceWarnings

 config = InvokeAIAppConfig.get_config()

@@ -8,7 +8,6 @@ from PIL import Image
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
-from invokeai.backend.model_management.models.base import calc_model_size_by_data

 from .resampler import Resampler

@@ -124,6 +123,9 @@ class IPAdapter:
         self.attn_weights.to(device=self.device, dtype=self.dtype)

     def calc_size(self):
+        # workaround for circular import
+        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
+
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

     def _init_image_proj_model(self, state_dict):
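The new `calc_size` body above defers the `calc_model_size_by_data` import into the method to avoid a circular import between the IP-Adapter module and the model-manager utilities. A generic, self-contained sketch of that deferred-import pattern; the class and the imported module below are placeholders, not the real InvokeAI code.

class Widget:
    """Toy stand-in for IPAdapter; it exists only to show the deferred import."""

    def calc_size(self) -> int:
        # Function-local import: the dependency is resolved when calc_size() runs,
        # not when this module is imported, so no cycle forms at import time.
        # The real code defers calc_model_size_by_data in exactly this way.
        import json  # stand-in for the module that would otherwise complete the cycle

        return len(json.dumps({"weights": [0.0] * 8}))


print(Widget().calc_size())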
@@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
 """
 import time
 from enum import Enum
-from typing import Literal, Optional, Type, Union, Class
+from typing import Literal, Optional, Type, Union

 import torch
 from diffusers import ModelMixin
@@ -335,7 +335,7 @@ class ModelConfigFactory(object):
         cls,
         model_data: Union[Dict[str, Any], AnyModelConfig],
         key: Optional[str] = None,
-        dest_class: Optional[Type[Class]] = None,
+        dest_class: Optional[Type[ModelConfigBase]] = None,
         timestamp: Optional[float] = None,
     ) -> AnyModelConfig:
         """
@@ -347,14 +347,17 @@ class ModelConfigFactory(object):
         :param dest_class: The config class to be returned. If not provided, will
                            be selected automatically.
         """
+        model: Optional[ModelConfigBase] = None
         if isinstance(model_data, ModelConfigBase):
             model = model_data
         elif dest_class:
-            model = dest_class.validate_python(model_data)
+            model = dest_class.model_validate(model_data)
         else:
-            model = AnyModelConfigValidator.validate_python(model_data)
+            # mypy doesn't typecheck TypeAdapters well?
+            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
+        assert model is not None
         if key:
             model.key = key
         if timestamp:
             model.last_modified = timestamp
-        return model
+        return model  # type: ignore
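The factory hunk above replaces the `dest_class.validate_python(...)` call with Pydantic v2's `model_validate`, while `validate_python` remains only on the `TypeAdapter`-based `AnyModelConfigValidator`. A minimal sketch of those two validation paths under Pydantic v2; the config classes and the union below are simplified stand-ins for the real model-config types.

from typing import Literal, Union

from pydantic import BaseModel, TypeAdapter


class VaeDiffusersConfigSketch(BaseModel):
    format: Literal["diffusers"] = "diffusers"
    path: str


class VaeCheckpointConfigSketch(BaseModel):
    format: Literal["checkpoint"] = "checkpoint"
    path: str


AnyVaeConfig = Union[VaeDiffusersConfigSketch, VaeCheckpointConfigSketch]
AnyVaeConfigValidator = TypeAdapter(AnyVaeConfig)

data = {"format": "checkpoint", "path": "models/vae.safetensors"}

# Path 1: the caller supplies dest_class, so that class validates the dict itself.
explicit = VaeCheckpointConfigSketch.model_validate(data)

# Path 2: no dest_class, so a TypeAdapter over the union picks the matching class.
inferred = AnyVaeConfigValidator.validate_python(data)

assert type(explicit) is type(inferred)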
@@ -18,8 +18,16 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Type

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+    VaeCheckpointConfig,
+    VaeDiffusersConfig,
+)
 from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.util.logging import InvokeAILogger
@@ -32,7 +40,7 @@ class LoadedModel:
     config: AnyModelConfig
     locker: ModelLockerBase

-    def __enter__(self) -> AnyModel:  # I think load_file() always returns a dict
+    def __enter__(self) -> AnyModel:
         """Context entry."""
         self.locker.lock()
         return self.model
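`LoadedModel.__enter__` above locks the model in the cache and returns the underlying model, which is why call sites use `LoadedModel` as a context manager. A rough self-contained sketch of that shape; the locker class and the `__exit__`/unlock half are assumptions here, since only `__enter__` appears in this hunk.

from dataclasses import dataclass
from typing import Any


class FakeLocker:
    """Stand-in for ModelLockerBase: pins a model in the cache while it is in use."""

    def __init__(self, model: Any) -> None:
        self.model = model
        self.locked = False

    def lock(self) -> Any:
        self.locked = True
        return self.model

    def unlock(self) -> None:
        self.locked = False


@dataclass
class LoadedModelSketch:
    locker: FakeLocker

    def __enter__(self) -> Any:
        self.locker.lock()
        return self.locker.model

    def __exit__(self, *exc: Any) -> None:
        # Assumed counterpart to __enter__: release the cache lock on exit.
        self.locker.unlock()


with LoadedModelSketch(locker=FakeLocker(model={"weights": []})) as model:
    print(type(model), "is locked while inside the with-block")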
@@ -171,6 +179,10 @@ class AnyModelLoader:
         def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
             cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
             key = cls._to_registry_key(base, type, format)
+            if key in cls._registry:
+                raise Exception(
+                    f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
+                )
             cls._registry[key] = subclass
             return subclass

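The hunk above adds a guard so that two loaders cannot silently register for the same (base, type, format) key. A compact sketch of the decorator-based registry pattern being protected; the class and string keys below are simplified and are not the actual AnyModelLoader implementation.

from typing import Callable, Dict, Tuple, Type


class LoaderRegistrySketch:
    _registry: Dict[Tuple[str, str, str], Type] = {}

    @classmethod
    def register(cls, base: str, type: str, format: str) -> Callable[[Type], Type]:
        def decorator(subclass: Type) -> Type:
            key = (base, type, format)
            if key in cls._registry:
                # Same guard as in the diff: refuse to overwrite an earlier loader.
                raise Exception(
                    f"{subclass.__name__} is trying to register for {key}, "
                    f"already handled by {cls._registry[key].__name__}"
                )
            cls._registry[key] = subclass
            return subclass

        return decorator


@LoaderRegistrySketch.register(base="any", type="clip_vision", format="diffusers")
class ClipVisionLoaderSketch:
    pass


print(LoaderRegistrySketch._registry)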
@@ -169,7 +169,7 @@ class ModelLoader(ModelLoaderBase):
             raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

     # This needs to be implemented in subclasses that handle checkpoints
-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         raise NotImplementedError

     # This needs to be implemented in the subclass
@@ -246,7 +246,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device."""
-        # These attributes are not in the base ModelMixin class but in derived classes.
+        # These attributes are not in the base ModelMixin class but in various derived classes.
         # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
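The reworded comment above belongs to `move_model_to_device`, which only attempts a transfer when the cached object actually exposes `device` and `to`; anything else simply stays in RAM/CPU. A hedged sketch of that defensive check, with a toy model class standing in for the cached entry.

import torch


class TinyModel:
    """Stand-in for a diffusers-style model that exposes .device and .to()."""

    def __init__(self) -> None:
        self.device = torch.device("cpu")

    def to(self, device: torch.device) -> "TinyModel":
        self.device = device
        return self


def move_model_to_device_sketch(model: object, target_device: torch.device) -> None:
    # Same guard as in the diff: `device`/`to` exist on various derived model
    # classes, not on every cached object, so objects lacking them are left alone.
    if not (hasattr(model, "device") and hasattr(model, "to")):
        return
    if model.device != target_device:
        model.to(target_device)


move_model_to_device_sketch(TinyModel(), torch.device("cpu"))
move_model_to_device_sketch({"not": "a model"}, torch.device("cpu"))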
@@ -35,28 +35,28 @@ class ControlnetLoader(GenericDiffusersLoader):
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
             raise Exception(f"Vae conversion not supported for model type: {config.base}")
         else:
             assert hasattr(config, "config")
             config_file = config.config

-        if weights_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        if model_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
         else:
-            checkpoint = torch.load(weights_path, map_location="cpu")
+            checkpoint = torch.load(model_path, map_location="cpu")

         # sometimes weights are hidden under "state_dict", and sometimes not
         if "state_dict" in checkpoint:
             checkpoint = checkpoint["state_dict"]

         convert_controlnet_to_diffusers(
-            weights_path,
+            model_path,
             output_path,
             original_config_file=self._app_config.root_path / config_file,
             image_size=512,
             scan_needed=True,
-            from_safetensors=weights_path.suffix == ".safetensors",
+            from_safetensors=model_path.suffix == ".safetensors",
         )
         return output_path
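Both the ControlNet loader above and the VAE loader below read the raw checkpoint the same way before converting it: ".safetensors" files go through `safetensors.torch.load_file`, everything else through `torch.load`, and a wrapping "state_dict" key is unwrapped when present. A small sketch of just that loading step; the example path in the comment is a placeholder.

from pathlib import Path

import safetensors.torch
import torch


def load_checkpoint_state_dict(model_path: Path) -> dict:
    """Load a raw checkpoint dict, mirroring the branch used by the loaders."""
    if model_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(model_path, device="cpu")
    else:
        checkpoint = torch.load(model_path, map_location="cpu")

    # Sometimes weights are hidden under "state_dict", and sometimes not.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint


# Usage (substitute a real checkpoint path to actually run this):
# state_dict = load_checkpoint_state_dict(Path("models/my_controlnet.safetensors"))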
@@ -12,8 +12,9 @@ from invokeai.backend.model_manager import (
     ModelType,
     SubModelType,
 )
-from invokeai.backend.model_manager.load.load_base import AnyModelLoader
-from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+from ..load_base import AnyModelLoader
+from ..load_default import ModelLoader


 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         assert isinstance(config, MainCheckpointConfig)
         variant = config.variant
         base = config.base
@@ -75,9 +75,9 @@ class StableDiffusionDiffusersModel(ModelLoader):

         config_file = config.config

-        self._logger.info(f"Converting {weights_path} to diffusers format")
+        self._logger.info(f"Converting {model_path} to diffusers format")
         convert_ckpt_to_diffusers(
-            weights_path,
+            model_path,
             output_path,
             model_type=self.model_base_to_model_type[base],
             model_version=base,
@@ -86,7 +86,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
             extract_ema=True,
             scan_needed=True,
             pipeline_class=pipeline_class,
-            from_safetensors=weights_path.suffix == ".safetensors",
+            from_safetensors=model_path.suffix == ".safetensors",
             precision=self._torch_dtype,
             load_safety_checker=False,
         )
@@ -37,7 +37,7 @@ class VaeLoader(GenericDiffusersLoader):
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         # TO DO: check whether sdxl VAE models convert.
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
             raise Exception(f"Vae conversion not supported for model type: {config.base}")
@@ -46,10 +46,10 @@ class VaeLoader(GenericDiffusersLoader):
             "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
         )

-        if weights_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        if model_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
         else:
-            checkpoint = torch.load(weights_path, map_location="cpu")
+            checkpoint = torch.load(model_path, map_location="cpu")

         # sometimes weights are hidden under "state_dict", and sometimes not
         if "state_dict" in checkpoint:
@@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
     bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
     other_files = set(all_files) - fp16_files - bit8_files

-    if variant is None:
+    if not variant:  # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF
         files = other_files
     elif variant == "fp16":
         files = fp16_files
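The change above swaps `variant is None` for a falsy check because `ModelRepoVariant.DEFAULT` serializes to an empty string (for compatibility with the HF hub), and an empty-string variant should select the same files as no variant at all. A rough sketch of that selection logic; the fp16 filename test and the toy file list are assumptions, not the full `calc_model_size_by_fs` implementation.

from pathlib import Path
from typing import Optional


def pick_variant_files(all_files: list, variant: Optional[str] = None) -> set:
    fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
    bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
    other_files = set(all_files) - fp16_files - bit8_files

    if not variant:  # catches both None and the empty-string "default" variant
        return other_files
    elif variant == "fp16":
        return fp16_files
    # (the real function handles further variants; the sketch stops here)
    return set()


files = [Path("unet/model.fp16.safetensors"), Path("unet/model.safetensors")]
print(pick_variant_files(files, variant=""))      # same result as variant=None
print(pick_variant_files(files, variant="fp16"))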
@@ -22,11 +22,12 @@ Example usage:

 import os
 from abc import ABC, abstractmethod
+from logging import Logger
 from pathlib import Path
 from typing import Callable, Optional, Set, Union

 from pydantic import BaseModel, Field
-from logging import Logger

 from invokeai.backend.util.logging import InvokeAILogger

+default_logger: Logger = InvokeAILogger.get_logger()
@@ -1 +1,3 @@
 from .schedulers import SCHEDULER_MAP  # noqa: F401
+
+__all__ = ["SCHEDULER_MAP"]