mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): misc typing fixes for model loaders
This commit is contained in:
@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
:param submodel_type: an ModelType enum indicating the portion of
|
:param submodel_type: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
"""
|
"""
|
||||||
if model_config.type == "main" and not submodel_type:
|
if model_config.type is ModelType.Main and not submodel_type:
|
||||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||||
|
|
||||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||||
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||||
return self._convert_model(config, model_path, cache_path)
|
return self._convert_model(config, model_path, cache_path)
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _load_if_needed(
|
def _load_if_needed(
|
||||||
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
return calc_model_size_by_fs(
|
return calc_model_size_by_fs(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
subfolder=submodel_type.value if submodel_type else None,
|
subfolder=submodel_type.value if submodel_type else None,
|
||||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This needs to be implemented in subclasses that handle checkpoints
|
# This needs to be implemented in subclasses that handle checkpoints
|
||||||
|
@ -59,6 +59,7 @@ class ModelLoaderRegistryBase(ABC):
|
|||||||
|
|
||||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderRegistry:
|
class ModelLoaderRegistry:
|
||||||
"""
|
"""
|
||||||
This class allows model loaders to register their type, base and format.
|
This class allows model loaders to register their type, base and format.
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import safetensors
|
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import load_file as safetensors_load_file
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||||
|
|
||||||
from .. import ModelLoaderRegistry
|
from .. import ModelLoaderRegistry
|
||||||
@ -20,7 +21,7 @@ from .generic_diffusers import GenericDiffusersLoader
|
|||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||||
class ControlnetLoader(GenericDiffusersLoader):
|
class ControlNetLoader(GenericDiffusersLoader):
|
||||||
"""Class to load ControlNet models."""
|
"""Class to load ControlNet models."""
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
|
|||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
||||||
else:
|
else:
|
||||||
assert hasattr(config, "config")
|
assert isinstance(config, CheckpointConfigBase)
|
||||||
config_file = config.config
|
config_file = config.config
|
||||||
|
|
||||||
if model_path.suffix == ".safetensors":
|
if model_path.suffix == ".safetensors":
|
||||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(model_path, map_location="cpu")
|
checkpoint = torch.load(model_path, map_location="cpu")
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
@ -42,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
# TO DO: Add exception handling
|
# TO DO: Add exception handling
|
||||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||||
|
result = None
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
try:
|
try:
|
||||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||||
@ -65,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||||
|
assert result is not None
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# TO DO: Add exception handling
|
# TO DO: Add exception handling
|
||||||
@ -76,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
result: ModelMixin = getattr(res_type, class_name)
|
result: ModelMixin = getattr(res_type, class_name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
|
||||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||||
|
|
||||||
|
|
||||||
@ -84,8 +86,8 @@ class ConfigLoader(ConfigMixin):
|
|||||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
|
||||||
"""Load a diffusrs ConfigMixin configuration."""
|
"""Load a diffusrs ConfigMixin configuration."""
|
||||||
cls.config_name = kwargs.pop("config_name")
|
cls.config_name = kwargs.pop("config_name")
|
||||||
# Diffusers doesn't provide typing info
|
# TODO(psyche): the types on this diffusers method are not correct
|
||||||
return super().load_config(*args, **kwargs) # type: ignore
|
return super().load_config(*args, **kwargs) # type: ignore
|
||||||
|
@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
|||||||
if submodel_type is not None:
|
if submodel_type is not None:
|
||||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||||
model = build_ip_adapter(
|
model = build_ip_adapter(
|
||||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
dtype=self._torch_dtype,
|
dtype=self._torch_dtype,
|
||||||
)
|
)
|
||||||
|
@ -4,7 +4,8 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
|
@ -3,9 +3,9 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import safetensors
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from safetensors.torch import load_file as safetensors_load_file
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -47,7 +47,7 @@ class VaeLoader(GenericDiffusersLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_path.suffix == ".safetensors":
|
if model_path.suffix == ".safetensors":
|
||||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(model_path, map_location="cpu")
|
checkpoint = torch.load(model_path, map_location="cpu")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user