Save models on rescan, uncache model on edit/delete, fixes

Sergey Borisov
2023-06-14 03:12:12 +03:00
parent 26090011c4
commit 740c05a0bb
8 changed files with 88 additions and 299 deletions

View File

@ -47,10 +47,6 @@ class ModelCache(object):
"Forward declaration"
pass
class SDModelInfo(object):
"""Forward declaration"""
pass
class _CacheRecord:
size: int
model: Any
@ -106,7 +102,7 @@ class ModelCache(object):
#max_cache_size = 9999
execution_device = torch.device('cuda')
self.model_infos: Dict[str, SDModelInfo] = dict()
self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision
@ -225,17 +221,16 @@ class ModelCache(object):
self.cache = cache
self.key = key
self.model = model
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, 'to'):
return self.model
cache_entry = self.cache._cached_models[self.key]
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
cache_entry.lock()
self.cache_entry.lock()
try:
if self.cache.lazy_offloading:
@ -251,14 +246,14 @@ class ModelCache(object):
self.cache._print_cuda_stats()
except:
cache_entry.unlock()
self.cache_entry.unlock()
raise
# TODO: not fully understood
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif cache_entry.loaded and not cache_entry.locked:
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.model.to(self.cache.storage_device)
return self.model
@ -267,12 +262,16 @@ class ModelCache(object):
if not hasattr(self.model, 'to'):
return
cache_entry = self.cache._cached_models[self.key]
cache_entry.unlock()
self.cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
# TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str):
with suppress(ValueError):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,

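For reference, a minimal, self-contained sketch of the eviction helper added above (the _CacheSketch class is a hypothetical stand-in for ModelCache; only the bookkeeping shown in the diff is reproduced):

from contextlib import suppress
from typing import Any, Dict, List

class _CacheSketch:
    """Hypothetical stand-in for ModelCache, reduced to its eviction bookkeeping."""

    def __init__(self) -> None:
        self._cached_models: Dict[str, Any] = dict()
        self._cache_stack: List[str] = list()

    def uncache_model(self, cache_id: str) -> None:
        # suppress(ValueError) turns list.remove() into a no-op when the id is
        # absent, and pop(..., None) does the same for the dict, so evicting a
        # model that was never cached (or evicting it twice) is harmless.
        with suppress(ValueError):
            self._cache_stack.remove(cache_id)
        self._cached_models.pop(cache_id, None)

cache = _CacheSketch()
cache.uncache_model("sd-1/main/example")  # safe even though nothing is cached yet

This idempotence is what lets the model manager loop over every tracked cache id on delete or overwrite without first checking whether each entry is still cached.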
View File

@ -234,38 +234,6 @@ class ModelManager(object):
logger: types.ModuleType = logger
# TODO:
def _convert_2_3_models(self, config: DictConfig):
for model_name, model_config in config.items():
if model_config["format"] == "diffusers":
pass
elif model_config["format"] == "ckpt":
if any(model_config["config"].endswith(file) for file in {
"v1-finetune.yaml",
"v1-finetune_style.yaml",
"v1-inference.yaml",
"v1-inpainting-inference.yaml",
"v1-m1-finetune.yaml",
}):
# copy as sd 1.5
pass
# should be ~99% accurate
elif model_config["config"].endswith("v2-inference-v.yaml"):
# copy as sd 2.x (768)
pass
# honestly not sure how accurate this is
elif model_config["config"].endswith("v2-inference.yaml"):
# copy as sd 2.x-base (512)
pass
else:
# TODO:
raise Exception("Unknown model")
def __init__(
self,
config: Union[Path, DictConfig, str],
@ -290,11 +258,9 @@ class ModelManager(object):
elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
#if "__meta__" not in config:
# config = self._convert_2_3_models(config)
config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: naming
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
# TODO: version check
self.models = dict()
for model_key, model_config in config.items():
@ -462,6 +428,10 @@ class ModelManager(object):
submodel=submodel_type,
)
if model_key not in self.cache_keys:
self.cache_keys[model_key] = set()
self.cache_keys[model_key].add(model_context.key)
model_hash = "<NO_HASH>" # TODO:
return ModelInfo(
@ -578,12 +548,12 @@ class ModelManager(object):
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
# TODO: test when ui implemented
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
delete_files: bool = False,
):
"""
Delete the named model.
@ -598,29 +568,20 @@ class ModelManager(object):
)
return
# TODO: some legacy?
#if model_name in self.stack:
# self.stack.remove(model_name)
# note: this does not guarantee that the memory is released (the model may have other references)
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
if delete_files:
repo_id = model_cfg.get("repo_id", None)
path = self._abs_path(model_cfg.get("path", None))
weights = self._abs_path(model_cfg.get("weights", None))
if "weights" in model_cfg:
weights = self._abs_path(model_cfg["weights"])
self.logger.info(f"Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif "path" in model_cfg:
path = self._abs_path(model_cfg["path"])
self.logger.info(f"Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif "repo_id" in model_cfg:
repo_id = model_cfg["repo_id"]
self.logger.info(f"Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
# if the model lives inside the InvokeAI models folder, delete its files
if model_cfg.path.startswith("models/") or model_cfg.path.startswith("models\\"):
model_path = self.globals.root_dir / model_cfg.path
if model_path.is_dir():
shutil.rmtree(str(model_path))
else:
model_path.unlink()
# TODO: test when ui implemented
def add_model(
self,
model_name: str,
@ -648,9 +609,10 @@ class ModelManager(object):
self.models[model_key] = model_config
if clobber and model_key in self.cache_keys:
# TODO:
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
# note: this does not guarantee that the memory is released (the model may have other references)
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
@ -678,6 +640,8 @@ class ModelManager(object):
Write current configuration out to the indicated file.
"""
data_to_save = dict()
data_to_save["__metadata__"] = self.config_meta.dict()
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
@ -711,46 +675,9 @@ class ModelManager(object):
"""
)
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
hashes_to_delete = set()
for repo in cache_info.repos:
if repo.repo_id == repo_id:
for revision in repo.revisions:
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
cls.logger.warning(
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()
@staticmethod
def _abs_path(path: str | Path) -> Path:
globals = InvokeAIAppConfig.get_config()
if path is None or Path(path).is_absolute():
return path
return Path(globals.root_dir, path).resolve()
# This is not the same as global_resolve_path(), which prepends
# Globals.root.
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
dest_directory = self.globals.root_dir / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
resolved_path = self.globals.root_dir / source
return resolved_path
def scan_models_directory(self):
loaded_files = set()
new_models_found = False
for model_key, model_config in list(self.models.items()):
model_name, base_model, model_type = self.parse_key(model_key)
@ -783,3 +710,7 @@ class ModelManager(object):
model_config: ModelConfigBase = model_class.probe_config(model_path)
self.models[model_key] = model_config
new_models_found = True
if new_models_found:
self.commit()

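A condensed sketch of the cache_keys bookkeeping that get_model, del_model and add_model share after this change (the _ManagerSketch class and its method names are illustrative stand-ins, not InvokeAI API):

from typing import Dict, Set

class _ManagerSketch:
    """Hypothetical stand-in for ModelManager, reduced to cache_keys handling."""

    def __init__(self, cache) -> None:
        self.cache = cache
        # one model key can map to several cache entries (e.g. per submodel)
        self.cache_keys: Dict[str, Set[str]] = dict()

    def track(self, model_key: str, cache_id: str) -> None:
        # mirrors get_model(): remember every cache entry created for this model
        if model_key not in self.cache_keys:
            self.cache_keys[model_key] = set()
        self.cache_keys[model_key].add(cache_id)

    def evict(self, model_key: str) -> None:
        # mirrors del_model() / add_model(clobber=True): drop every tracked entry.
        # Note: this does not guarantee the memory is freed; other references to
        # the model object may still exist.
        for cache_id in self.cache_keys.pop(model_key, set()):
            self.cache.uncache_model(cache_id)

On top of this, scan_models_directory only adds the new_models_found flag: when a rescan discovers a model that was not in the config, it calls self.commit() once at the end, which is what persists the newly found models.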
View File

@ -3,11 +3,13 @@ import sys
import typing
import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
@ -52,6 +54,9 @@ class ModelConfigBase(BaseModel):
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
@ -59,7 +64,18 @@ class EmptyConfigLoader(ConfigMixin):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
T_co = TypeVar('T_co', covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
@ -121,7 +137,7 @@ class ModelBase:
return cls.__configs
@classmethod
def create_config(cls, **kwargs):
def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs:
raise Exception("Field 'format' not found in model config")
@ -129,16 +145,33 @@ class ModelBase:
return configs[kwargs["format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs):
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):

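A standalone illustration of the new classproperty descriptor (the ExampleModel class is hypothetical; the descriptor body mirrors the one added in the diff):

from typing import Any, Callable, Generic, Optional, Type, TypeVar

T_co = TypeVar('T_co', covariant=True)

class classproperty(Generic[T_co]):
    def __init__(self, fget: Callable[[Any], T_co]) -> None:
        self.fget = fget

    def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
        # invoked for both ExampleModel.save_to_config and instance.save_to_config;
        # the class, not the instance, is always passed to fget
        return self.fget(owner)

    def __set__(self, instance: Optional[Any], value: Any) -> None:
        raise AttributeError('cannot set attribute')

class ExampleModel:
    @classproperty
    def save_to_config(cls) -> bool:
        return False

assert ExampleModel.save_to_config is False      # read like an attribute, no call
assert ExampleModel().save_to_config is False    # works on instances too

With @classmethod, save_to_config had to be called as save_to_config(); with the descriptor it reads as a plain boolean attribute, which is how the model classes in the following files now expose it.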
View File

@ -6,6 +6,7 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
@ -43,7 +44,7 @@ class LoRAModel(ModelBase):
self.model_size = model.calc_size()
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False

View File

@ -1,7 +1,5 @@
import os
import json
import torch
import safetensors.torch
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@ -16,6 +14,7 @@ from .base import (
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
@ -87,7 +86,7 @@ class StableDiffusion1Model(DiffusersModel):
variant=variant,
)
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return True
@ -198,7 +197,7 @@ class StableDiffusion2Model(DiffusersModel):
upcast_attention=upcast_attention,
)
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return True

View File

@ -6,6 +6,7 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -43,7 +44,7 @@ class TextualInversionModel(ModelBase):
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False

View File

@ -12,6 +12,7 @@ from .base import (
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
@ -62,7 +63,7 @@ class VaeModel(ModelBase):
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False