fix(mm): misc typing fixes for model loaders

This commit is contained in:
psychedelicious 2024-03-01 13:39:06 +11:00
parent c561cd751f
commit e426096d32
7 changed files with 22 additions and 16 deletions

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type == "main" and not submodel_type:
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
# This needs to be implemented in subclasses that handle checkpoints

View File

@ -59,6 +59,7 @@ class ModelLoaderRegistryBase(ABC):
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.

View File

@ -3,8 +3,8 @@
from pathlib import Path
import safetensors
import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry
@ -20,7 +21,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
assert isinstance(config, CheckpointConfigBase)
config_file = config.config
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")

View File

@ -3,7 +3,7 @@
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
@ -42,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -65,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result
# TO DO: Add exception handling
@ -76,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
result: ModelMixin = getattr(res_type, class_name)
return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
@ -84,8 +86,8 @@ class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
# TODO(psyche): the types on this diffusers method are not correct
return super().load_config(*args, **kwargs) # type: ignore

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=self._torch_dtype,
)

View File

@ -4,7 +4,8 @@
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import (
AnyModel,

View File

@ -3,9 +3,9 @@
from pathlib import Path
import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -47,7 +47,7 @@ class VaeLoader(GenericDiffusersLoader):
)
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")