Save models on rescan, uncache model on edit/delete, fixes
This commit is contained in: parent 26090011c4, commit 740c05a0bb
@@ -140,7 +140,6 @@ class ModelManagerServiceBase(ABC):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
         delete_files: bool = False,
     ):
         """
         Delete the named model from configuration. If delete_files is true,
@@ -149,91 +148,6 @@ class ModelManagerServiceBase(ABC):
         """
         pass
 
-    @abstractmethod
-    def import_diffuser_model(
-        repo_or_path: Union[str, Path],
-        model_name: Optional[str] = None,
-        description: Optional[str] = None,
-        vae: Optional[dict] = None,
-    ) -> bool:
-        """
-        Install the indicated diffuser model and returns True if successful.
-
-        "repo_or_path" can be either a repo-id or a path-like object corresponding to the
-        top of a downloaded diffusers directory.
-
-        You can optionally provide a model name and/or description. If not provided,
-        then these will be derived from the repo name. Call commit() to write to disk.
-        """
-        pass
-
-    @abstractmethod
-    def import_lora(
-        self,
-        path: Path,
-        model_name: Optional[str] = None,
-        description: Optional[str] = None,
-    ):
-        """
-        Creates an entry for the indicated lora file. Call
-        mgr.commit() to write out the configuration to models.yaml
-        """
-        pass
-
-    @abstractmethod
-    def import_embedding(
-        self,
-        path: Path,
-        model_name: str=None,
-        description: str=None,
-    ):
-        """
-        Creates an entry for the indicated textual inversion embedding file.
-        Call commit() to write out the configuration to models.yaml
-        """
-        pass
-
-    @abstractmethod
-    def heuristic_import(
-        self,
-        path_url_or_repo: str,
-        model_name: str = None,
-        description: str = None,
-        model_config_file: Path = None,
-        commit_to_conf: Path = None,
-        config_file_callback: Callable[[Path], Path] = None,
-    ) -> str:
-        """Accept a string which could be:
-           - a HF diffusers repo_id
-           - a URL pointing to a legacy .ckpt or .safetensors file
-           - a local path pointing to a legacy .ckpt or .safetensors file
-           - a local directory containing .ckpt and .safetensors files
-           - a local directory containing a diffusers model
-
-        After determining the nature of the model and downloading it
-        (if necessary), the file is probed to determine the correct
-        configuration file (if needed) and it is imported.
-
-        The model_name and/or description can be provided. If not, they will
-        be generated automatically.
-
-        If commit_to_conf is provided, the newly loaded model will be written
-        to the `models.yaml` file at the indicated path. Otherwise, the changes
-        will only remain in memory.
-
-        The routine will do its best to figure out the config file
-        needed to convert legacy checkpoint file, but if it can't it
-        will call the config_file_callback routine, if provided. The
-        callback accepts a single argument, the Path to the checkpoint
-        file, and returns a Path to the config file to use.
-
-        The (potentially derived) name of the model is returned on
-        success, or None on failure. When multiple models are added
-        from a directory, only the last imported one is returned.
-        """
-        pass
 
     @abstractmethod
     def commit(self, conf_file: Path = None) -> None:
         """
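The heuristic_import() contract above survives on the underlying ModelManager even though the service-level wrappers are dropped by this commit. Below is a rough, purely illustrative sketch of the input dispatch the docstring describes; the helper name and the model_index.json check are assumptions, not InvokeAI code:

```python
from pathlib import Path

def classify_import_source(path_url_or_repo: str) -> str:
    """Guess which kind of model source a string refers to (illustrative)."""
    s = str(path_url_or_repo)
    if s.startswith(("http:", "https:", "ftp:")):
        return "url"                      # legacy .ckpt/.safetensors to download
    p = Path(s)
    if p.is_file() and p.suffix in {".ckpt", ".safetensors"}:
        return "checkpoint-file"          # probe for the matching legacy config
    if p.is_dir():
        if (p / "model_index.json").exists():
            return "diffusers-directory"  # diffusers layouts carry model_index.json
        return "weights-directory"        # scan it for .ckpt/.safetensors files
    return "repo-id"                      # otherwise treat it as a HF repo_id
```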
@@ -424,103 +338,13 @@ class ModelManagerService(ModelManagerServiceBase):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
         delete_files: bool = False,
     ):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
         as well. Call commit() to write to disk.
         """
-        self.mgr.del_model(model_name, base_model, model_type)
+        self.mgr.del_model(model_name, base_model, model_type, delete_files)
 
-    def import_diffuser_model(
-        self,
-        repo_or_path: Union[str, Path],
-        model_name: Optional[str] = None,
-        description: Optional[str] = None,
-        vae: Optional[dict] = None,
-    ) -> bool:
-        """
-        Install the indicated diffuser model and returns True if successful.
-
-        "repo_or_path" can be either a repo-id or a path-like object corresponding to the
-        top of a downloaded diffusers directory.
-
-        You can optionally provide a model name and/or description. If not provided,
-        then these will be derived from the repo name. Call commit() to write to disk.
-        """
-        return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
-
-    def import_lora(
-        self,
-        path: Path,
-        model_name: Optional[str] = None,
-        description: Optional[str] = None,
-    ):
-        """
-        Creates an entry for the indicated lora file. Call
-        mgr.commit() to write out the configuration to models.yaml
-        """
-        self.mgr.import_lora(path, model_name, description)
-
-    def import_embedding(
-        self,
-        path: Path,
-        model_name: Optional[str] = None,
-        description: Optional[str] = None,
-    ):
-        """
-        Creates an entry for the indicated textual inversion embedding file.
-        Call commit() to write out the configuration to models.yaml
-        """
-        self.mgr.import_embedding(path, model_name, description)
-
-    def heuristic_import(
-        self,
-        path_url_or_repo: str,
-        model_name: str = None,
-        description: str = None,
-        model_config_file: Optional[Path] = None,
-        commit_to_conf: Optional[Path] = None,
-        config_file_callback: Optional[Callable[[Path], Path]] = None,
-    ) -> str:
-        """Accept a string which could be:
-           - a HF diffusers repo_id
-           - a URL pointing to a legacy .ckpt or .safetensors file
-           - a local path pointing to a legacy .ckpt or .safetensors file
-           - a local directory containing .ckpt and .safetensors files
-           - a local directory containing a diffusers model
-
-        After determining the nature of the model and downloading it
-        (if necessary), the file is probed to determine the correct
-        configuration file (if needed) and it is imported.
-
-        The model_name and/or description can be provided. If not, they will
-        be generated automatically.
-
-        If commit_to_conf is provided, the newly loaded model will be written
-        to the `models.yaml` file at the indicated path. Otherwise, the changes
-        will only remain in memory.
-
-        The routine will do its best to figure out the config file
-        needed to convert legacy checkpoint file, but if it can't it
-        will call the config_file_callback routine, if provided. The
-        callback accepts a single argument, the Path to the checkpoint
-        file, and returns a Path to the config file to use.
-
-        The (potentially derived) name of the model is returned on
-        success, or None on failure. When multiple models are added
-        from a directory, only the last imported one is returned.
-        """
-        return self.mgr.heuristic_import(
-            path_url_or_repo,
-            model_name,
-            description,
-            model_config_file,
-            commit_to_conf,
-            config_file_callback
-        )
 
     def commit(self, conf_file: Optional[Path]=None):
@@ -47,10 +47,6 @@ class ModelCache(object):
     "Forward declaration"
     pass
 
-class SDModelInfo(object):
-    """Forward declaration"""
-    pass
-
 class _CacheRecord:
     size: int
     model: Any
@@ -106,7 +102,7 @@ class ModelCache(object):
         #max_cache_size = 9999
         execution_device = torch.device('cuda')
 
-        self.model_infos: Dict[str, SDModelInfo] = dict()
+        self.model_infos: Dict[str, ModelBase] = dict()
         self.lazy_offloading = lazy_offloading
         #self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
@@ -225,17 +221,16 @@ class ModelCache(object):
             self.cache = cache
             self.key = key
             self.model = model
+            self.cache_entry = self.cache._cached_models[self.key]
 
         def __enter__(self) -> Any:
             if not hasattr(self.model, 'to'):
                 return self.model
 
-            cache_entry = self.cache._cached_models[self.key]
-
             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
             if self.gpu_load:
-                cache_entry.lock()
+                self.cache_entry.lock()
 
             try:
                 if self.cache.lazy_offloading:
@@ -251,14 +246,14 @@ class ModelCache(object):
                     self.cache._print_cuda_stats()
 
                 except:
-                    cache_entry.unlock()
+                    self.cache_entry.unlock()
                     raise
 
 
             # TODO: not fully understand
             # in the event that the caller wants the model in RAM, we
             # move it into CPU if it is in GPU and not locked
-            elif cache_entry.loaded and not cache_entry.locked:
+            elif self.cache_entry.loaded and not self.cache_entry.locked:
                 self.model.to(self.cache.storage_device)
 
             return self.model
@@ -267,12 +262,16 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return
 
-            cache_entry = self.cache._cached_models[self.key]
-            cache_entry.unlock()
+            self.cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
                 self.cache._print_cuda_stats()
 
+    # TODO: should it be called untrack_model?
+    def uncache_model(self, cache_id: str):
+        with suppress(ValueError):
+            self._cache_stack.remove(cache_id)
+        self._cached_models.pop(cache_id, None)
 
     def model_hash(
         self,
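Capturing the cache record once in __init__, rather than re-reading self.cache._cached_models[self.key] in __enter__ and __exit__, means lock() and unlock() always operate on the same object even if the new uncache_model() pops the key out of _cached_models while the model is checked out; that is plausibly the motivation for this change. A minimal self-contained sketch of the lock/uncache bookkeeping (not the InvokeAI classes):

```python
from contextlib import suppress
from typing import Any, Dict, List

class CacheRecord:
    """Tracks how many contexts currently hold a model (illustrative)."""
    def __init__(self, model: Any) -> None:
        self.model = model
        self._locks = 0

    def lock(self) -> None:
        self._locks += 1

    def unlock(self) -> None:
        self._locks = max(0, self._locks - 1)

    @property
    def locked(self) -> bool:
        return self._locks > 0

class TinyCache:
    def __init__(self) -> None:
        self._cached_models: Dict[str, CacheRecord] = {}
        self._cache_stack: List[str] = []

    def uncache_model(self, cache_id: str) -> None:
        # Drops our references only; memory is reclaimed when no other
        # references remain (the same caveat as the comment in the diff).
        with suppress(ValueError):
            self._cache_stack.remove(cache_id)
        self._cached_models.pop(cache_id, None)
```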
@@ -234,38 +234,6 @@ class ModelManager(object):
 
     logger: types.ModuleType = logger
 
-    # TODO:
-    def _convert_2_3_models(self, config: DictConfig):
-        for model_name, model_config in config.items():
-            if model_config["format"] == "diffusers":
-                pass
-            elif model_config["format"] == "ckpt":
-
-                if any(model_config["config"].endswith(file) for file in {
-                    "v1-finetune.yaml",
-                    "v1-finetune_style.yaml",
-                    "v1-inference.yaml",
-                    "v1-inpainting-inference.yaml",
-                    "v1-m1-finetune.yaml",
-                }):
-                    # copy as as sd1.5
-                    pass
-
-                # ~99% accurate should be
-                elif model_config["config"].endswith("v2-inference-v.yaml"):
-                    # copy as sd 2.x (768)
-                    pass
-
-                # for real don't know how accurate it
-                elif model_config["config"].endswith("v2-inference.yaml"):
-                    # copy as sd 2.x-base (512)
-                    pass
-
-                else:
-                    # TODO:
-                    raise Exception("Unknown model")
 
     def __init__(
         self,
         config: Union[Path, DictConfig, str],
@@ -290,11 +258,9 @@ class ModelManager(object):
         elif not isinstance(config, DictConfig):
             raise ValueError('config argument must be an OmegaConf object, a Path or a string')
 
-        #if "__meta__" not in config:
-        #    config = self._convert_2_3_models(config)
-
-        config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: naming
+        self.config_meta = ConfigMeta(**config.pop("__metadata__"))
         # TODO: metadata not found
         # TODO: version check
 
         self.models = dict()
         for model_key, model_config in config.items():
@@ -462,6 +428,10 @@ class ModelManager(object):
             submodel=submodel_type,
         )
 
+        if model_key not in self.cache_keys:
+            self.cache_keys[model_key] = set()
+        self.cache_keys[model_key].add(model_context.key)
+
         model_hash = "<NO_HASH>" # TODO:
 
         return ModelInfo(
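One models.yaml key can fan out to several cache entries (unet, vae, tokenizer, and so on), so get_model() now records every cache key it produces under the model's key. The shape of that bookkeeping, as a sketch with assumed names:

```python
from collections import defaultdict
from typing import Callable

cache_keys: "defaultdict[str, set[str]]" = defaultdict(set)

def remember(model_key: str, cache_key: str) -> None:
    cache_keys[model_key].add(cache_key)   # what the hunk above does inline

def evict_all(model_key: str, uncache: Callable[[str], None]) -> None:
    for cache_key in cache_keys.pop(model_key, set()):
        uncache(cache_key)                 # what del_model/add_model do later
```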
@@ -578,12 +548,12 @@ class ModelManager(object):
             line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
             print(line)
 
     # TODO: test when ui implemented
     def del_model(
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
+        delete_files: bool = False,
     ):
         """
         Delete the named model.
@@ -598,29 +568,20 @@ class ModelManager(object):
             )
             return
 
-        # TODO: some legacy?
-        #if model_name in self.stack:
-        #    self.stack.remove(model_name)
+        # note: it not garantie to release memory(model can has other references)
+        cache_ids = self.cache_keys.pop(model_key, [])
+        for cache_id in cache_ids:
+            self.cache.uncache_model(cache_id)
 
         if delete_files:
-            repo_id = model_cfg.get("repo_id", None)
-            path = self._abs_path(model_cfg.get("path", None))
-            weights = self._abs_path(model_cfg.get("weights", None))
-            if "weights" in model_cfg:
-                weights = self._abs_path(model_cfg["weights"])
-                self.logger.info(f"Deleting file {weights}")
-                Path(weights).unlink(missing_ok=True)
-
-            elif "path" in model_cfg:
-                path = self._abs_path(model_cfg["path"])
-                self.logger.info(f"Deleting directory {path}")
-                rmtree(path, ignore_errors=True)
-
-            elif "repo_id" in model_cfg:
-                repo_id = model_cfg["repo_id"]
-                self.logger.info(f"Deleting the cached model directory for {repo_id}")
-                self._delete_model_from_cache(repo_id)
+            # if model inside invoke models folder - delete files
+            if model_cfg.path.startswith("models/") or model_cfg.path.startswith("models\\"):
+                model_path = self.globals.root_dir / model_cfg.path
+                if model_path.isdir():
+                    shutil.rmtree(str(model_path))
+                else:
+                    model_path.unlink()
 
     # TODO: test when ui implemented
     def add_model(
         self,
         model_name: str,
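Two notes on the rewritten deletion path. The eviction loop mirrors the one added to add_model() below, so edits and deletes both drop stale cache entries. Separately, pathlib.Path spells the directory test is_dir(), not isdir() (that name belongs to os.path), so the added branch would raise AttributeError for directory models. A corrected sketch of the same logic, under the assumption that the config path is relative to the InvokeAI root:

```python
import shutil
from pathlib import Path

def delete_model_files(root_dir: Path, cfg_path: str) -> None:
    # Only touch files that live inside the managed models/ folder.
    if not cfg_path.startswith(("models/", "models\\")):
        return
    model_path = root_dir / cfg_path
    if model_path.is_dir():                 # Path API, not .isdir()
        shutil.rmtree(model_path)
    else:
        model_path.unlink(missing_ok=True)
```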
@@ -648,9 +609,10 @@ class ModelManager(object):
         self.models[model_key] = model_config
 
         if clobber and model_key in self.cache_keys:
-            # TODO:
-            self.cache.uncache_model(self.cache_keys[model_key])
-            del self.cache_keys[model_key]
+            # note: it not garantie to release memory(model can has other references)
+            cache_ids = self.cache_keys.pop(model_key, [])
+            for cache_id in cache_ids:
+                self.cache.uncache_model(cache_id)
 
     def search_models(self, search_folder):
         self.logger.info(f"Finding Models In: {search_folder}")
@@ -678,6 +640,8 @@ class ModelManager(object):
         Write current configuration out to the indicated file.
         """
         data_to_save = dict()
+        data_to_save["__metadata__"] = self.config_meta.dict()
 
         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
             model_class = MODEL_CLASSES[base_model][model_type]
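Since commit() now writes the __metadata__ block back out, the file round-trips with the ConfigMeta that __init__ pops off. Roughly, with the per-model dump shape assumed rather than taken from the source:

```python
import yaml

def serialize(config_meta, models: dict) -> str:
    data_to_save = {"__metadata__": config_meta.dict()}   # e.g. a version field
    for model_key, model_config in models.items():
        data_to_save[model_key] = model_config.dict(exclude_defaults=True)
    return yaml.dump(data_to_save)
```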
@@ -711,46 +675,9 @@ class ModelManager(object):
         """
         )
 
-    @classmethod
-    def _delete_model_from_cache(cls,repo_id):
-        cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
-
-        # I'm sure there is a way to do this with comprehensions
-        # but the code quickly became incomprehensible!
-        hashes_to_delete = set()
-        for repo in cache_info.repos:
-            if repo.repo_id == repo_id:
-                for revision in repo.revisions:
-                    hashes_to_delete.add(revision.commit_hash)
-        strategy = cache_info.delete_revisions(*hashes_to_delete)
-        cls.logger.warning(
-            f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
-        )
-        strategy.execute()
-
-    @staticmethod
-    def _abs_path(path: str | Path) -> Path:
-        globals = InvokeAIAppConfig.get_config()
-        if path is None or Path(path).is_absolute():
-            return path
-        return Path(globals.root_dir, path).resolve()
-
-    # This is not the same as global_resolve_path(), which prepends
-    # Globals.root.
-    def _resolve_path(
-        self, source: Union[str, Path], dest_directory: str
-    ) -> Optional[Path]:
-        resolved_path = None
-        if str(source).startswith(("http:", "https:", "ftp:")):
-            dest_directory = self.globals.root_dir / dest_directory
-            dest_directory.mkdir(parents=True, exist_ok=True)
-            resolved_path = download_with_resume(str(source), dest_directory)
-        else:
-            resolved_path = self.globals.root_dir / source
-        return resolved_path
 
     def scan_models_directory(self):
         loaded_files = set()
         new_models_found = False
 
         for model_key, model_config in list(self.models.items()):
             model_name, base_model, model_type = self.parse_key(model_key)
@@ -783,3 +710,7 @@ class ModelManager(object):
 
             model_config: ModelConfigBase = model_class.probe_config(model_path)
             self.models[model_key] = model_config
             new_models_found = True
 
+        if new_models_found:
+            self.commit()
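This hunk is the "save models on rescan" half of the commit title: scan_models_directory() now persists its discoveries, and does so with a single commit at the end rather than one write per model. The flow in miniature, with probe and save standing in for probe_config() and commit():

```python
def rescan(registry: dict, found_on_disk: dict, probe, save) -> None:
    new_models_found = False
    for model_key, model_path in found_on_disk.items():
        if model_key not in registry:
            registry[model_key] = probe(model_path)  # build a config for it
            new_models_found = True
    if new_models_found:
        save(registry)                               # one write, only if needed
```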
@@ -3,11 +3,13 @@ import sys
 import typing
 import inspect
 from enum import Enum
+from abc import ABCMeta, abstractmethod
 
 import torch
 import safetensors.torch
 from diffusers import DiffusionPipeline, ConfigMixin
 
 from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Type, Literal
+from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
 
 class BaseModelType(str, Enum):
     StableDiffusion1 = "sd-1"
@@ -52,6 +54,9 @@ class ModelConfigBase(BaseModel):
     # do not save to config
     error: Optional[ModelError] = Field(None, exclude=True)
 
+    class Config:
+        use_enum_values = True
 
 
 class EmptyConfigLoader(ConfigMixin):
     @classmethod
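use_enum_values = True makes pydantic store each enum field as its value, so configs serialize as plain strings like "sd-1" rather than enum members. A small demonstration, assuming pydantic v1 semantics (which the class Config style here implies):

```python
from enum import Enum
from pydantic import BaseModel

class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"

class WithValues(BaseModel):
    base: BaseModelType

    class Config:
        use_enum_values = True

print(WithValues(base=BaseModelType.StableDiffusion1).base)  # 'sd-1', a plain str
```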
@@ -59,7 +64,18 @@ class EmptyConfigLoader(ConfigMixin):
         cls.config_name = kwargs.pop("config_name")
         return super().load_config(*args, **kwargs)
 
-class ModelBase:
+T_co = TypeVar('T_co', covariant=True)
+
+class classproperty(Generic[T_co]):
+    def __init__(self, fget: Callable[[Any], T_co]) -> None:
+        self.fget = fget
+
+    def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
+        return self.fget(owner)
+
+    def __set__(self, instance: Optional[Any], value: Any) -> None:
+        raise AttributeError('cannot set attribute')
+
+class ModelBase(metaclass=ABCMeta):
     #model_path: str
     #base_model: BaseModelType
     #model_type: ModelType
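The reason for the descriptor: with plain @classmethod, SomeModel.save_to_config evaluates to a bound method object, which is truthy whether the method would return True or False, so a check like `if model_class.save_to_config:` silently passes for every class. classproperty runs fget at attribute access instead. Using the definition above:

```python
class Example:
    @classproperty
    def save_to_config(cls) -> bool:
        return False

print(Example.save_to_config)  # False: __get__ calls fget(owner), no () needed
# With @classmethod, `Example.save_to_config` would be a method object and
# `if Example.save_to_config:` would always take the truthy branch.
```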
@@ -121,7 +137,7 @@ class ModelBase:
         return cls.__configs
 
     @classmethod
-    def create_config(cls, **kwargs):
+    def create_config(cls, **kwargs) -> ModelConfigBase:
         if "format" not in kwargs:
             raise Exception("Field 'format' not found in model config")
@@ -129,16 +145,33 @@ class ModelBase:
         return configs[kwargs["format"]](**kwargs)
 
     @classmethod
-    def probe_config(cls, path: str, **kwargs):
+    def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
         return cls.create_config(
             path=path,
             format=cls.detect_format(path),
         )
 
     @classmethod
     @abstractmethod
     def detect_format(cls, path: str) -> str:
         raise NotImplementedError()
 
+    @classproperty
+    @abstractmethod
+    def save_to_config(cls) -> bool:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SubModelType] = None,
+    ) -> Any:
+        raise NotImplementedError()
 
 
 class DiffusersModel(ModelBase):
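For orientation, a hedged sketch of the smallest concrete subclass that satisfies the abstract surface above; since ModelBase's constructor fields are commented out in this hunk, nothing about __init__ is assumed here:

```python
class NullModel(ModelBase):
    """Illustrative only: trivially implements each abstract member."""

    @classmethod
    def detect_format(cls, path: str) -> str:
        return "null"

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        return 0

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> Any:
        raise NotImplementedError("sketch only; nothing to load")
```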
@@ -6,6 +6,7 @@ from .base import (
     BaseModelType,
     ModelType,
     SubModelType,
+    classproperty,
 )
 # TODO: naming
 from ..lora import LoRAModel as LoRAModelRaw
@@ -43,7 +44,7 @@ class LoRAModel(ModelBase):
         self.model_size = model.calc_size()
         return model
 
-    @classmethod
+    @classproperty
     def save_to_config(cls) -> bool:
         return False
@@ -1,7 +1,5 @@
 import os
 import json
-import torch
-import safetensors.torch
 from pydantic import Field
 from pathlib import Path
 from typing import Literal, Optional, Union
@@ -16,6 +14,7 @@ from .base import (
     SchedulerPredictionType,
     SilenceWarnings,
     read_checkpoint_meta,
+    classproperty,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
 from omegaconf import OmegaConf
@@ -87,7 +86,7 @@ class StableDiffusion1Model(DiffusersModel):
             variant=variant,
         )
 
-    @classmethod
+    @classproperty
     def save_to_config(cls) -> bool:
         return True
@@ -198,7 +197,7 @@ class StableDiffusion2Model(DiffusersModel):
             upcast_attention=upcast_attention,
         )
 
-    @classmethod
+    @classproperty
     def save_to_config(cls) -> bool:
         return True
@@ -6,6 +6,7 @@ from .base import (
     BaseModelType,
     ModelType,
     SubModelType,
+    classproperty,
 )
 # TODO: naming
 from ..lora import TextualInversionModel as TextualInversionModelRaw
@@ -43,7 +44,7 @@ class TextualInversionModel(ModelBase):
         self.model_size = model.embedding.nelement() * model.embedding.element_size()
         return model
 
-    @classmethod
+    @classproperty
     def save_to_config(cls) -> bool:
         return False
@@ -12,6 +12,7 @@ from .base import (
     EmptyConfigLoader,
     calc_model_size_by_fs,
     calc_model_size_by_data,
+    classproperty,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
 from diffusers.utils import is_safetensors_available
@@ -62,7 +63,7 @@ class VaeModel(ModelBase):
         self.model_size = calc_model_size_by_data(model)
         return model
 
-    @classmethod
+    @classproperty
     def save_to_config(cls) -> bool:
         return False