mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bugfix/autodetect-sdxl-ckpt-config
This commit is contained in:
@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import os
|
||||
import textwrap
|
||||
import yaml
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree, move
|
||||
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@ -259,6 +259,7 @@ from .models import (
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
DuplicateModelException,
|
||||
ModelBase,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -361,7 +362,7 @@ class ModelManager(object):
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
@ -381,18 +382,24 @@ class ModelManager(object):
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
Given a model name, returns True if it is a valid identifier.
|
||||
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param rescan: if True, scan_models_directory
|
||||
"""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
return model_key in self.models
|
||||
exists = model_key in self.models
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if rescan and not exists:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
|
||||
|
||||
return exists
|
||||
|
||||
@classmethod
|
||||
def create_key(
|
||||
@ -443,39 +450,32 @@ class ModelManager(object):
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param submode_typel: an ModelType enum indicating the portion of
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if model_key not in self.models:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
if model_key not in self.models:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
if not self.model_exists(model_name, base_model, model_type, rescan=True):
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.models[model_key]
|
||||
model_path = self.resolve_model_path(model_config.path)
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if is_submodel_override:
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
|
||||
if not model_path.exists():
|
||||
if model_class.save_to_config:
|
||||
self.models[model_key].error = ModelError.NotFound
|
||||
raise Exception(f'Files for model "{model_key}" not found')
|
||||
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = self.resolve_path(override_path)
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
@ -513,6 +513,55 @@ class ModelManager(object):
|
||||
_cache=self.cache,
|
||||
)
|
||||
|
||||
def _get_model_path(
|
||||
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
|
||||
) -> (Path, bool):
|
||||
"""Extract a model's filesystem path from its config.
|
||||
|
||||
:return: The fully qualified Path of the module (or submodule).
|
||||
"""
|
||||
model_path = model_config.path
|
||||
is_submodel_override = False
|
||||
|
||||
# Does the config explicitly override the submodel?
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
submodel_path = getattr(model_config, submodel_type)
|
||||
if submodel_path is not None:
|
||||
model_path = getattr(model_config, submodel_type)
|
||||
is_submodel_override = True
|
||||
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
|
||||
"""Get a model's config object."""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
try:
|
||||
model_config = self.models[model_key]
|
||||
except KeyError:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
return model_config
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
"""Get the concrete implementation class for a specific model type."""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
return model_class
|
||||
|
||||
def _instantiate(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelBase:
|
||||
"""Make a new instance of this model, without loading it."""
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
# FIXME: do non-overriden submodels get the right class?
|
||||
constructor = self._get_implementation(base_model, model_type)
|
||||
instance = constructor(model_path, base_model, model_type)
|
||||
return instance
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -546,9 +595,10 @@ class ModelManager(object):
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model, model_type, model_name)
|
||||
if len(models) > 1:
|
||||
if len(models) >= 1:
|
||||
return models[0]
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
@ -660,7 +710,7 @@ class ModelManager(object):
|
||||
if path := model_attributes.get("path"):
|
||||
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
model_config = model_class.create_config(**model_attributes)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
@ -851,7 +901,7 @@ class ModelManager(object):
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
@ -909,7 +959,7 @@ class ModelManager(object):
|
||||
|
||||
model_path = self.resolve_model_path(model_config.path).absolute()
|
||||
if not model_path.exists():
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
self.models.pop(model_key, None)
|
||||
@ -925,7 +975,7 @@ class ModelManager(object):
|
||||
for cur_model_type in ModelType:
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
|
||||
|
||||
if not models_dir.exists():
|
||||
|
Reference in New Issue
Block a user