fix(mm): misc typing fixes for model loaders

This commit is contained in:
psychedelicious
2024-03-01 13:39:06 +11:00
parent c561cd751f
commit e426096d32
7 changed files with 22 additions and 16 deletions

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelRepoVariant, ModelRepoVariant,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of :param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
""" """
if model_config.type == "main" and not submodel_type: if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model") raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path) return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False return False
def _load_if_needed( def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs( return calc_model_size_by_fs(
model_path=model_path, model_path=model_path,
subfolder=submodel_type.value if submodel_type else None, subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if hasattr(config, "repo_variant") else None, variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
) )
# This needs to be implemented in subclasses that handle checkpoints # This needs to be implemented in subclasses that handle checkpoints

View File

@ -59,6 +59,7 @@ class ModelLoaderRegistryBase(ABC):
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase) TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry: class ModelLoaderRegistry:
""" """
This class allows model loaders to register their type, base and format. This class allows model loaders to register their type, base and format.

View File

@ -3,8 +3,8 @@
from pathlib import Path from pathlib import Path
import safetensors
import torch import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -20,7 +21,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader): class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models.""" """Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}") raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else: else:
assert hasattr(config, "config") assert isinstance(config, CheckpointConfigBase)
config_file = config.config config_file = config.config
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu") checkpoint = safetensors_load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(model_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")

View File

@ -3,7 +3,7 @@
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
@ -42,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling # TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type: if submodel_type:
try: try:
config = self._load_diffusers_config(model_path, config_name="model_index.json") config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -65,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result return result
# TO DO: Add exception handling # TO DO: Add exception handling
@ -76,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
result: ModelMixin = getattr(res_type, class_name) result: ModelMixin = getattr(res_type, class_name)
return result return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name) return ConfigLoader.load_config(model_path, config_name=config_name)
@ -84,8 +86,8 @@ class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files.""" """Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod @classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
"""Load a diffusrs ConfigMixin configuration.""" """Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name") cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info # TODO(psyche): the types on this diffusers method are not correct
return super().load_config(*args, **kwargs) # type: ignore return super().load_config(*args, **kwargs) # type: ignore

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
if submodel_type is not None: if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.") raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter( model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin", ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"), device=torch.device("cpu"),
dtype=self._torch_dtype, dtype=self._torch_dtype,
) )

View File

@ -4,7 +4,8 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,

View File

@ -3,9 +3,9 @@
from pathlib import Path from pathlib import Path
import safetensors
import torch import torch
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
@ -47,7 +47,7 @@ class VaeLoader(GenericDiffusersLoader):
) )
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu") checkpoint = safetensors_load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(model_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")