remove unused code from invokeai.backend.model_manager.storage.yaml

This commit is contained in:
Lincoln Stein
2023-09-29 01:07:18 -04:00
parent 3b832f1db2
commit 4555aec17c

View File

@ -48,6 +48,7 @@ from typing import List, Optional, Set, Union
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
@ -63,8 +64,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
_filename: Path
_config: DictConfig
_lock: threading.Lock
_config: Union[DictConfig, ListConfig]
_lock: threading.RLock
def __init__(self, config_file: Path):
"""Initialize ModelConfigStore object with a .yaml file."""
@ -127,7 +128,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._lock.release()
def _fix_enums(self, original: dict) -> dict:
"""In python 3.9, omegaconf stores incorrectly stringified enums"""
"""In python 3.9, omegaconf stores incorrectly stringified enums."""
fixed_dict = {}
for key, value in original.items():
fixed_dict[key] = value.value if isinstance(value, Enum) else value
@ -203,7 +204,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
try:
self._lock.acquire()
for config in self.all_models():
config_tags = set(config.tags)
config_tags = set(config.tags or [])
if tags.difference(config_tags): # not all tags in the model
continue
results.append(config)
@ -227,13 +228,13 @@ class ModelConfigStoreYAML(ModelConfigStore):
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
results: List[ModelConfigBase] = list()
try:
self._lock.acquire()
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record, key)
model = ModelConfigFactory.make_config(record, str(key))
if model_name and model.name != model_name:
continue
if base_model and model.base_model != base_model:
@ -246,65 +247,15 @@ class ModelConfigStoreYAML(ModelConfigStore):
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""
Return the model with the indicated path, or None..
"""
"""Return the model with the indicated path, or None."""
try:
self._lock.acquire()
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record, key)
model = ModelConfigFactory.make_config(record, str(key))
if model.path == path:
return model
finally:
self._lock.release()
return None
def _load_and_maybe_upgrade(self, config_path: Path) -> DictConfig:
config = OmegaConf.load(config_path)
version = config["__metadata__"].get("version")
if version == CONFIG_FILE_VERSION:
return config
# if we get here we need to upgrade
if version == "3.0.0":
return self._migrate_format_to_3_2(config, config_path)
else:
raise Exception(f"{config_path} has unknown version: {version}")
def _migrate_format_to_3_2(self, old_config: DictConfig, config_path: Path) -> DictConfig:
print(
f"** Doing one-time conversion of {config_path.as_posix()} to new format. Original will be named {config_path.as_posix() + '.orig'}"
)
# avoid circular dependencies
from shutil import move
from ..install import InvalidModelException, ModelInstall
move(config_path, config_path.as_posix() + ".orig")
new_store = self.__class__(config_path)
installer = ModelInstall(store=new_store)
for model_key, stanza in old_config.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
try:
path = stanza["path"]
new_key = installer.register_path(path)
model_info = new_store.get_model(new_key)
if vae := stanza.get("vae"):
model_info.vae = vae
if model_config := stanza.get("config"):
model_info.config = model_config.as_posix()
model_info.description = stanza.get("description")
new_store.update_model(new_key, model_info)
return OmegaConf.load(config_path)
except (DuplicateModelException, InvalidModelException) as e:
print(str(e))