mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Save models on rescan, uncache model on edit/delete, fixes
This commit is contained in:
parent
26090011c4
commit
740c05a0bb
@ -140,7 +140,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
@ -149,91 +148,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def import_diffuser_model(
|
|
||||||
repo_or_path: Union[str, Path],
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
vae: Optional[dict] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Install the indicated diffuser model and returns True if successful.
|
|
||||||
|
|
||||||
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
|
|
||||||
top of a downloaded diffusers directory.
|
|
||||||
|
|
||||||
You can optionally provide a model name and/or description. If not provided,
|
|
||||||
then these will be derived from the repo name. Call commit() to write to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def import_lora(
|
|
||||||
self,
|
|
||||||
path: Path,
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an entry for the indicated lora file. Call
|
|
||||||
mgr.commit() to write out the configuration to models.yaml
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def import_embedding(
|
|
||||||
self,
|
|
||||||
path: Path,
|
|
||||||
model_name: str=None,
|
|
||||||
description: str=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an entry for the indicated textual inversion embedding file.
|
|
||||||
Call commit() to write out the configuration to models.yaml
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
path_url_or_repo: str,
|
|
||||||
model_name: str = None,
|
|
||||||
description: str = None,
|
|
||||||
model_config_file: Path = None,
|
|
||||||
commit_to_conf: Path = None,
|
|
||||||
config_file_callback: Callable[[Path], Path] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Accept a string which could be:
|
|
||||||
- a HF diffusers repo_id
|
|
||||||
- a URL pointing to a legacy .ckpt or .safetensors file
|
|
||||||
- a local path pointing to a legacy .ckpt or .safetensors file
|
|
||||||
- a local directory containing .ckpt and .safetensors files
|
|
||||||
- a local directory containing a diffusers model
|
|
||||||
|
|
||||||
After determining the nature of the model and downloading it
|
|
||||||
(if necessary), the file is probed to determine the correct
|
|
||||||
configuration file (if needed) and it is imported.
|
|
||||||
|
|
||||||
The model_name and/or description can be provided. If not, they will
|
|
||||||
be generated automatically.
|
|
||||||
|
|
||||||
If commit_to_conf is provided, the newly loaded model will be written
|
|
||||||
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
|
||||||
will only remain in memory.
|
|
||||||
|
|
||||||
The routine will do its best to figure out the config file
|
|
||||||
needed to convert legacy checkpoint file, but if it can't it
|
|
||||||
will call the config_file_callback routine, if provided. The
|
|
||||||
callback accepts a single argument, the Path to the checkpoint
|
|
||||||
file, and returns a Path to the config file to use.
|
|
||||||
|
|
||||||
The (potentially derived) name of the model is returned on
|
|
||||||
success, or None on failure. When multiple models are added
|
|
||||||
from a directory, only the last imported one is returned.
|
|
||||||
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Path = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -424,103 +338,13 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
self.mgr.del_model(model_name, base_model, model_type, delete_files)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
|
|
||||||
def import_diffuser_model(
|
|
||||||
self,
|
|
||||||
repo_or_path: Union[str, Path],
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
vae: Optional[dict] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Install the indicated diffuser model and returns True if successful.
|
|
||||||
|
|
||||||
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
|
|
||||||
top of a downloaded diffusers directory.
|
|
||||||
|
|
||||||
You can optionally provide a model name and/or description. If not provided,
|
|
||||||
then these will be derived from the repo name. Call commit() to write to disk.
|
|
||||||
"""
|
|
||||||
return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
|
|
||||||
|
|
||||||
def import_lora(
|
|
||||||
self,
|
|
||||||
path: Path,
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an entry for the indicated lora file. Call
|
|
||||||
mgr.commit() to write out the configuration to models.yaml
|
|
||||||
"""
|
|
||||||
self.mgr.import_lora(path, model_name, description)
|
|
||||||
|
|
||||||
def import_embedding(
|
|
||||||
self,
|
|
||||||
path: Path,
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an entry for the indicated textual inversion embedding file.
|
|
||||||
Call commit() to write out the configuration to models.yaml
|
|
||||||
"""
|
|
||||||
self.mgr.import_embedding(path, model_name, description)
|
|
||||||
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
path_url_or_repo: str,
|
|
||||||
model_name: str = None,
|
|
||||||
description: str = None,
|
|
||||||
model_config_file: Optional[Path] = None,
|
|
||||||
commit_to_conf: Optional[Path] = None,
|
|
||||||
config_file_callback: Optional[Callable[[Path], Path]] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Accept a string which could be:
|
|
||||||
- a HF diffusers repo_id
|
|
||||||
- a URL pointing to a legacy .ckpt or .safetensors file
|
|
||||||
- a local path pointing to a legacy .ckpt or .safetensors file
|
|
||||||
- a local directory containing .ckpt and .safetensors files
|
|
||||||
- a local directory containing a diffusers model
|
|
||||||
|
|
||||||
After determining the nature of the model and downloading it
|
|
||||||
(if necessary), the file is probed to determine the correct
|
|
||||||
configuration file (if needed) and it is imported.
|
|
||||||
|
|
||||||
The model_name and/or description can be provided. If not, they will
|
|
||||||
be generated automatically.
|
|
||||||
|
|
||||||
If commit_to_conf is provided, the newly loaded model will be written
|
|
||||||
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
|
||||||
will only remain in memory.
|
|
||||||
|
|
||||||
The routine will do its best to figure out the config file
|
|
||||||
needed to convert legacy checkpoint file, but if it can't it
|
|
||||||
will call the config_file_callback routine, if provided. The
|
|
||||||
callback accepts a single argument, the Path to the checkpoint
|
|
||||||
file, and returns a Path to the config file to use.
|
|
||||||
|
|
||||||
The (potentially derived) name of the model is returned on
|
|
||||||
success, or None on failure. When multiple models are added
|
|
||||||
from a directory, only the last imported one is returned.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self.mgr.heuristic_import(
|
|
||||||
path_url_or_repo,
|
|
||||||
model_name,
|
|
||||||
description,
|
|
||||||
model_config_file,
|
|
||||||
commit_to_conf,
|
|
||||||
config_file_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path]=None):
|
def commit(self, conf_file: Optional[Path]=None):
|
||||||
|
@ -47,10 +47,6 @@ class ModelCache(object):
|
|||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SDModelInfo(object):
|
|
||||||
"""Forward declaration"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _CacheRecord:
|
class _CacheRecord:
|
||||||
size: int
|
size: int
|
||||||
model: Any
|
model: Any
|
||||||
@ -106,7 +102,7 @@ class ModelCache(object):
|
|||||||
#max_cache_size = 9999
|
#max_cache_size = 9999
|
||||||
execution_device = torch.device('cuda')
|
execution_device = torch.device('cuda')
|
||||||
|
|
||||||
self.model_infos: Dict[str, SDModelInfo] = dict()
|
self.model_infos: Dict[str, ModelBase] = dict()
|
||||||
self.lazy_offloading = lazy_offloading
|
self.lazy_offloading = lazy_offloading
|
||||||
#self.sequential_offload: bool=sequential_offload
|
#self.sequential_offload: bool=sequential_offload
|
||||||
self.precision: torch.dtype=precision
|
self.precision: torch.dtype=precision
|
||||||
@ -225,17 +221,16 @@ class ModelCache(object):
|
|||||||
self.cache = cache
|
self.cache = cache
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.cache_entry = self.cache._cached_models[self.key]
|
||||||
|
|
||||||
def __enter__(self) -> Any:
|
def __enter__(self) -> Any:
|
||||||
if not hasattr(self.model, 'to'):
|
if not hasattr(self.model, 'to'):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
cache_entry = self.cache._cached_models[self.key]
|
|
||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this
|
# NOTE that the model has to have the to() method in order for this
|
||||||
# code to move it into GPU!
|
# code to move it into GPU!
|
||||||
if self.gpu_load:
|
if self.gpu_load:
|
||||||
cache_entry.lock()
|
self.cache_entry.lock()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.cache.lazy_offloading:
|
if self.cache.lazy_offloading:
|
||||||
@ -251,14 +246,14 @@ class ModelCache(object):
|
|||||||
self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
except:
|
except:
|
||||||
cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# TODO: not fully understand
|
# TODO: not fully understand
|
||||||
# in the event that the caller wants the model in RAM, we
|
# in the event that the caller wants the model in RAM, we
|
||||||
# move it into CPU if it is in GPU and not locked
|
# move it into CPU if it is in GPU and not locked
|
||||||
elif cache_entry.loaded and not cache_entry.locked:
|
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
||||||
self.model.to(self.cache.storage_device)
|
self.model.to(self.cache.storage_device)
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
@ -267,12 +262,16 @@ class ModelCache(object):
|
|||||||
if not hasattr(self.model, 'to'):
|
if not hasattr(self.model, 'to'):
|
||||||
return
|
return
|
||||||
|
|
||||||
cache_entry = self.cache._cached_models[self.key]
|
self.cache_entry.unlock()
|
||||||
cache_entry.unlock()
|
|
||||||
if not self.cache.lazy_offloading:
|
if not self.cache.lazy_offloading:
|
||||||
self.cache._offload_unlocked_models()
|
self.cache._offload_unlocked_models()
|
||||||
self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
|
# TODO: should it be called untrack_model?
|
||||||
|
def uncache_model(self, cache_id: str):
|
||||||
|
with suppress(ValueError):
|
||||||
|
self._cache_stack.remove(cache_id)
|
||||||
|
self._cached_models.pop(cache_id, None)
|
||||||
|
|
||||||
def model_hash(
|
def model_hash(
|
||||||
self,
|
self,
|
||||||
|
@ -234,38 +234,6 @@ class ModelManager(object):
|
|||||||
|
|
||||||
logger: types.ModuleType = logger
|
logger: types.ModuleType = logger
|
||||||
|
|
||||||
# TODO:
|
|
||||||
def _convert_2_3_models(self, config: DictConfig):
|
|
||||||
for model_name, model_config in config.items():
|
|
||||||
if model_config["format"] == "diffusers":
|
|
||||||
pass
|
|
||||||
elif model_config["format"] == "ckpt":
|
|
||||||
|
|
||||||
if any(model_config["config"].endswith(file) for file in {
|
|
||||||
"v1-finetune.yaml",
|
|
||||||
"v1-finetune_style.yaml",
|
|
||||||
"v1-inference.yaml",
|
|
||||||
"v1-inpainting-inference.yaml",
|
|
||||||
"v1-m1-finetune.yaml",
|
|
||||||
}):
|
|
||||||
# copy as as sd1.5
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ~99% accurate should be
|
|
||||||
elif model_config["config"].endswith("v2-inference-v.yaml"):
|
|
||||||
# copy as sd 2.x (768)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# for real don't know how accurate it
|
|
||||||
elif model_config["config"].endswith("v2-inference.yaml"):
|
|
||||||
# copy as sd 2.x-base (512)
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
# TODO:
|
|
||||||
raise Exception("Unknown model")
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Union[Path, DictConfig, str],
|
config: Union[Path, DictConfig, str],
|
||||||
@ -290,11 +258,9 @@ class ModelManager(object):
|
|||||||
elif not isinstance(config, DictConfig):
|
elif not isinstance(config, DictConfig):
|
||||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
#if "__meta__" not in config:
|
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
||||||
# config = self._convert_2_3_models(config)
|
|
||||||
|
|
||||||
config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: naming
|
|
||||||
# TODO: metadata not found
|
# TODO: metadata not found
|
||||||
|
# TODO: version check
|
||||||
|
|
||||||
self.models = dict()
|
self.models = dict()
|
||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
@ -462,6 +428,10 @@ class ModelManager(object):
|
|||||||
submodel=submodel_type,
|
submodel=submodel_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_key not in self.cache_keys:
|
||||||
|
self.cache_keys[model_key] = set()
|
||||||
|
self.cache_keys[model_key].add(model_context.key)
|
||||||
|
|
||||||
model_hash = "<NO_HASH>" # TODO:
|
model_hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
@ -578,12 +548,12 @@ class ModelManager(object):
|
|||||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
|
# TODO: test when ui implemented
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
delete_files: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete the named model.
|
Delete the named model.
|
||||||
@ -598,29 +568,20 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: some legacy?
|
# note: it not garantie to release memory(model can has other references)
|
||||||
#if model_name in self.stack:
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
# self.stack.remove(model_name)
|
for cache_id in cache_ids:
|
||||||
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
if delete_files:
|
# if model inside invoke models folder - delete files
|
||||||
repo_id = model_cfg.get("repo_id", None)
|
if model_cfg.path.startswith("models/") or model_cfg.path.startswith("models\\"):
|
||||||
path = self._abs_path(model_cfg.get("path", None))
|
model_path = self.globals.root_dir / model_cfg.path
|
||||||
weights = self._abs_path(model_cfg.get("weights", None))
|
if model_path.isdir():
|
||||||
if "weights" in model_cfg:
|
shutil.rmtree(str(model_path))
|
||||||
weights = self._abs_path(model_cfg["weights"])
|
else:
|
||||||
self.logger.info(f"Deleting file {weights}")
|
model_path.unlink()
|
||||||
Path(weights).unlink(missing_ok=True)
|
|
||||||
|
|
||||||
elif "path" in model_cfg:
|
|
||||||
path = self._abs_path(model_cfg["path"])
|
|
||||||
self.logger.info(f"Deleting directory {path}")
|
|
||||||
rmtree(path, ignore_errors=True)
|
|
||||||
|
|
||||||
elif "repo_id" in model_cfg:
|
|
||||||
repo_id = model_cfg["repo_id"]
|
|
||||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
|
||||||
self._delete_model_from_cache(repo_id)
|
|
||||||
|
|
||||||
|
# TODO: test when ui implemented
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -648,9 +609,10 @@ class ModelManager(object):
|
|||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
|
||||||
if clobber and model_key in self.cache_keys:
|
if clobber and model_key in self.cache_keys:
|
||||||
# TODO:
|
# note: it not garantie to release memory(model can has other references)
|
||||||
self.cache.uncache_model(self.cache_keys[model_key])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
del self.cache_keys[model_key]
|
for cache_id in cache_ids:
|
||||||
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
@ -678,6 +640,8 @@ class ModelManager(object):
|
|||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
data_to_save = dict()
|
data_to_save = dict()
|
||||||
|
data_to_save["__metadata__"] = self.config_meta.dict()
|
||||||
|
|
||||||
for model_key, model_config in self.models.items():
|
for model_key, model_config in self.models.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -711,46 +675,9 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
|
||||||
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
|
|
||||||
|
|
||||||
# I'm sure there is a way to do this with comprehensions
|
|
||||||
# but the code quickly became incomprehensible!
|
|
||||||
hashes_to_delete = set()
|
|
||||||
for repo in cache_info.repos:
|
|
||||||
if repo.repo_id == repo_id:
|
|
||||||
for revision in repo.revisions:
|
|
||||||
hashes_to_delete.add(revision.commit_hash)
|
|
||||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
|
||||||
cls.logger.warning(
|
|
||||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
|
||||||
)
|
|
||||||
strategy.execute()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _abs_path(path: str | Path) -> Path:
|
|
||||||
globals = InvokeAIAppConfig.get_config()
|
|
||||||
if path is None or Path(path).is_absolute():
|
|
||||||
return path
|
|
||||||
return Path(globals.root_dir, path).resolve()
|
|
||||||
|
|
||||||
# This is not the same as global_resolve_path(), which prepends
|
|
||||||
# Globals.root.
|
|
||||||
def _resolve_path(
|
|
||||||
self, source: Union[str, Path], dest_directory: str
|
|
||||||
) -> Optional[Path]:
|
|
||||||
resolved_path = None
|
|
||||||
if str(source).startswith(("http:", "https:", "ftp:")):
|
|
||||||
dest_directory = self.globals.root_dir / dest_directory
|
|
||||||
dest_directory.mkdir(parents=True, exist_ok=True)
|
|
||||||
resolved_path = download_with_resume(str(source), dest_directory)
|
|
||||||
else:
|
|
||||||
resolved_path = self.globals.root_dir / source
|
|
||||||
return resolved_path
|
|
||||||
|
|
||||||
def scan_models_directory(self):
|
def scan_models_directory(self):
|
||||||
loaded_files = set()
|
loaded_files = set()
|
||||||
|
new_models_found = False
|
||||||
|
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
@ -783,3 +710,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_config: ModelConfigBase = model_class.probe_config(model_path)
|
model_config: ModelConfigBase = model_class.probe_config(model_path)
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
new_models_found = True
|
||||||
|
|
||||||
|
if new_models_found:
|
||||||
|
self.commit()
|
||||||
|
@ -3,11 +3,13 @@ import sys
|
|||||||
import typing
|
import typing
|
||||||
import inspect
|
import inspect
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
import torch
|
import torch
|
||||||
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
@ -52,6 +54,9 @@ class ModelConfigBase(BaseModel):
|
|||||||
# do not save to config
|
# do not save to config
|
||||||
error: Optional[ModelError] = Field(None, exclude=True)
|
error: Optional[ModelError] = Field(None, exclude=True)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
class EmptyConfigLoader(ConfigMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -59,7 +64,18 @@ class EmptyConfigLoader(ConfigMixin):
|
|||||||
cls.config_name = kwargs.pop("config_name")
|
cls.config_name = kwargs.pop("config_name")
|
||||||
return super().load_config(*args, **kwargs)
|
return super().load_config(*args, **kwargs)
|
||||||
|
|
||||||
class ModelBase:
|
T_co = TypeVar('T_co', covariant=True)
|
||||||
|
class classproperty(Generic[T_co]):
|
||||||
|
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
||||||
|
self.fget = fget
|
||||||
|
|
||||||
|
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
||||||
|
return self.fget(owner)
|
||||||
|
|
||||||
|
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||||
|
raise AttributeError('cannot set attribute')
|
||||||
|
|
||||||
|
class ModelBase(metaclass=ABCMeta):
|
||||||
#model_path: str
|
#model_path: str
|
||||||
#base_model: BaseModelType
|
#base_model: BaseModelType
|
||||||
#model_type: ModelType
|
#model_type: ModelType
|
||||||
@ -121,7 +137,7 @@ class ModelBase:
|
|||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_config(cls, **kwargs):
|
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||||
if "format" not in kwargs:
|
if "format" not in kwargs:
|
||||||
raise Exception("Field 'format' not found in model config")
|
raise Exception("Field 'format' not found in model config")
|
||||||
|
|
||||||
@ -129,16 +145,33 @@ class ModelBase:
|
|||||||
return configs[kwargs["format"]](**kwargs)
|
return configs[kwargs["format"]](**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=cls.detect_format(path),
|
format=cls.detect_format(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
def detect_format(cls, path: str) -> str:
|
def detect_format(cls, path: str) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
@abstractmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
) -> Any:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class DiffusersModel(ModelBase):
|
class DiffusersModel(ModelBase):
|
||||||
|
@ -6,6 +6,7 @@ from .base import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
|
classproperty,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
@ -43,7 +44,7 @@ class LoRAModel(ModelBase):
|
|||||||
self.model_size = model.calc_size()
|
self.model_size = model.calc_size()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classproperty
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
@ -16,6 +14,7 @@ from .base import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SilenceWarnings,
|
SilenceWarnings,
|
||||||
read_checkpoint_meta,
|
read_checkpoint_meta,
|
||||||
|
classproperty,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -87,7 +86,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
variant=variant,
|
variant=variant,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classproperty
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -198,7 +197,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classproperty
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from .base import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
|
classproperty,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||||
@ -43,7 +44,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classproperty
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from .base import (
|
|||||||
EmptyConfigLoader,
|
EmptyConfigLoader,
|
||||||
calc_model_size_by_fs,
|
calc_model_size_by_fs,
|
||||||
calc_model_size_by_data,
|
calc_model_size_by_data,
|
||||||
|
classproperty,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
@ -62,7 +63,7 @@ class VaeModel(ModelBase):
|
|||||||
self.model_size = calc_model_size_by_data(model)
|
self.model_size = calc_model_size_by_data(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classproperty
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user