loaders for main, controlnet, ip-adapter, clipvision and t2i

This commit is contained in:
Lincoln Stein
2024-02-04 17:23:10 -05:00
committed by psychedelicious
parent 8ba5360269
commit 67eb715093
32 changed files with 1123 additions and 159 deletions

View File

@ -10,12 +10,12 @@ from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -38,7 +38,7 @@ class ModelLoader(ModelLoaderBase):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
@ -47,7 +47,6 @@ class ModelLoader(ModelLoaderBase):
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device())
self._size: Optional[int] = None # model size
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
@ -63,9 +62,7 @@ class ModelLoader(ModelLoaderBase):
if model_config.type == "main" and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
submodel_type = None
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
@ -74,13 +71,12 @@ class ModelLoader(ModelLoaderBase):
locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, locker=locker)
# IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
# and submodels!!!!
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, bool]:
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_base = self._app_config.models_path
return ((model_base / config.path).resolve(), False)
result = (model_base / config.path).resolve(), config, submodel_type
return result
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@ -90,7 +86,7 @@ class ModelLoader(ModelLoaderBase):
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
@ -114,6 +110,7 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(config.key, submodel_type)
@ -128,17 +125,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
@ -161,3 +147,17 @@ class ModelLoader(ModelLoaderBase):
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError