Files
InvokeAI/invokeai/backend/model_manager/storage/yaml.py
2023-09-29 19:23:08 -04:00

263 lines
8.8 KiB
Python

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Implementation of ModelConfigStore using a YAML file.

Typical usage:

    from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML

    store = ModelConfigStoreYAML("./configs/models.yaml")
    config = dict(
        path='/tmp/pokemon.bin',
        name='old name',
        base_model='sd-1',
        model_type='embedding',
        model_format='embedding_file',
        author='Anonymous',
        tags=['sfw', 'cartoon'],
    )

    # adding - the key becomes the model's "key" field
    store.add_model('key1', config)

    # updating (config is a plain dict, so use item assignment)
    config['name'] = 'new name'
    store.update_model('key1', config)

    # checking for existence
    if store.exists('key1'):
        print("yes")

    # fetching config
    new_config = store.get_model('key1')
    print(new_config.name, new_config.base_model)
    assert new_config.key == 'key1'

    # deleting
    store.del_model('key1')

    # searching
    configs = store.search_by_tag({'sfw', 'oss license'})  # models having ALL of these tags
    configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import threading
from enum import Enum
from pathlib import Path
from typing import Iterator, List, Optional, Set, Union

import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
    CONFIG_FILE_VERSION,
    ConfigFileVersionMismatchException,
    DuplicateModelException,
    ModelConfigStore,
    UnknownModelException,
)
class ModelConfigStoreYAML(ModelConfigStore):
    """Implementation of the ModelConfigStore ABC using a YAML file."""

    # Top-level YAML key holding file metadata (currently the version). It is
    # not a model record, so every model operation below treats it as absent.
    _METADATA_KEY = "__metadata__"

    _filename: Path
    _config: DictConfig
    _lock: threading.RLock

    def __init__(self, config_file: Union[str, Path]):
        """
        Initialize ModelConfigStore object with a .yaml file.

        :param config_file: Path to the YAML file. If it does not exist it is
          created (with any missing parent directories) containing only the
          metadata header.

        Can raise ConfigFileVersionMismatchException if the file's recorded
        version differs from CONFIG_FILE_VERSION (or no version is recorded).
        """
        super().__init__()
        self._filename = Path(config_file).absolute()  # don't let chdir mess us up!
        self._lock = threading.RLock()
        if not self._filename.exists():
            self._initialize_yaml()
        config = OmegaConf.load(self._filename)
        assert isinstance(config, DictConfig)
        self._config = config
        if str(self.version) != CONFIG_FILE_VERSION:
            raise ConfigFileVersionMismatchException(
                f"{self._filename}: expected config file version {CONFIG_FILE_VERSION}, found {self.version}"
            )

    def _initialize_yaml(self):
        """Create the YAML file containing only the metadata/version header."""
        with self._lock:
            self._filename.parent.mkdir(parents=True, exist_ok=True)
            # utf-8 to match the encoding used by _commit()
            with open(self._filename, "w", encoding="utf-8") as yaml_file:
                yaml_file.write(yaml.dump({self._METADATA_KEY: {"version": CONFIG_FILE_VERSION}}))

    def _commit(self):
        """
        Write the in-memory config back to disk.

        The YAML is first written to a sibling ".new" file, which then replaces
        the original via Path.replace(), so a failed write does not leave the
        original file half-written.
        """
        with self._lock:
            newfile = self._filename.with_name(self._filename.name + ".new")
            yaml_str = OmegaConf.to_yaml(self._config)
            with open(newfile, "w", encoding="utf-8") as outfile:
                outfile.write(yaml_str)
            newfile.replace(self._filename)

    @property
    def version(self) -> str:
        """
        Return version of this config file/database.

        Returns None when the file has no metadata section or no version in it,
        so that __init__ reports a version mismatch instead of crashing.
        """
        metadata = self._config.get(self._METADATA_KEY)
        if not isinstance(metadata, DictConfig):
            return None
        return metadata.get("version")

    def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
        """
        Add a model to the database.

        :param key: Unique key for the model
        :param config: Model configuration record, either a dict with the
          required fields or a ModelConfigBase instance.

        Can raise DuplicateModelException and InvalidModelConfig exceptions.
        DuplicateModelException is also raised for the reserved metadata key.
        """
        record = ModelConfigFactory.make_config(config, key)  # ensure it is a valid config object
        dict_fields = record.dict()  # and back to a dict with valid fields
        with self._lock:
            if key == self._METADATA_KEY:
                raise DuplicateModelException(f"Can't save {record.name} because the key '{key}' is reserved")
            if key in self._config:
                existing_model = self.get_model(key)
                raise DuplicateModelException(
                    f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
                )
            self._config[key] = self._fix_enums(dict_fields)
            self._commit()

    def _fix_enums(self, original: dict) -> dict:
        """
        Return a copy of `original` with top-level Enum values replaced by their `.value`.

        In python 3.9, omegaconf stores incorrectly stringified enums.
        Nested containers are not converted.
        """
        return {key: value.value if isinstance(value, Enum) else value for key, value in original.items()}

    def del_model(self, key: str) -> None:
        """
        Delete a model.

        :param key: Unique key for the model to be deleted

        Can raise an UnknownModelException (also for the reserved metadata key,
        which must never be removed).
        """
        with self._lock:
            if not self.exists(key):
                raise UnknownModelException(f"Unknown key '{key}' for model config")
            self._config.pop(key)
            self._commit()

    def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
        """
        Update the model, returning the updated version.

        :param key: Unique key for the model to be updated
        :param config: Model configuration record. Either a dict with the
          required fields, or a ModelConfigBase instance.

        Can raise UnknownModelException (also for the reserved metadata key)
        and InvalidModelConfig exceptions.
        """
        record = ModelConfigFactory.make_config(config, key)  # ensure it is a valid config object
        dict_fields = record.dict()  # and back to a dict with valid fields
        with self._lock:
            if not self.exists(key):
                raise UnknownModelException(f"Unknown key '{key}' for model config")
            self._config[key] = self._fix_enums(dict_fields)
            self._commit()
            return self.get_model(key)

    def get_model(self, key: str) -> AnyModelConfig:
        """
        Retrieve the ModelConfigBase instance for the indicated model.

        :param key: Key of model config to be fetched.

        Exceptions: UnknownModelException
        """
        with self._lock:
            if not self.exists(key):
                raise UnknownModelException(f"Unknown key '{key}' for model config")
            record = self._config[key]
            return ModelConfigFactory.make_config(record, key)

    def exists(self, key: str) -> bool:
        """
        Return True if a model with the indicated key exists in the database.

        :param key: Unique key for the model to look for. The reserved
          metadata key is never reported as a model.
        """
        with self._lock:
            return key != self._METADATA_KEY and key in self._config

    def _iter_models(self) -> Iterator[AnyModelConfig]:
        """Yield a config object for every stored model, skipping metadata. Caller must hold self._lock."""
        for key, record in self._config.items():
            if key == self._METADATA_KEY:
                continue
            yield ModelConfigFactory.make_config(record, str(key))

    def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
        """
        Return models containing all of the listed tags.

        :param tags: Set of tags to search on.
        """
        wanted = set(tags)
        with self._lock:
            # keep a model only if every wanted tag appears among its tags
            return [config for config in self.all_models() if wanted.issubset(config.tags or [])]

    def search_by_name(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
    ) -> List[ModelConfigBase]:
        """
        Return models matching name, base and/or type.

        :param model_name: Filter by name of model (optional)
        :param base_model: Filter by base model (optional)
        :param model_type: Filter by type of model (optional)

        If none of the optional filters are passed, will return all
        models in the database.
        """
        results: List[ModelConfigBase] = []
        with self._lock:
            for model in self._iter_models():
                if model_name and model.name != model_name:
                    continue
                if base_model and model.base_model != base_model:
                    continue
                if model_type and model.model_type != model_type:
                    continue
                results.append(model)
        return results

    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
        """
        Return the model with the indicated path, or None.

        Paths are compared as pathlib.Path objects, so a str and a Path naming
        the same location match (redundant separators and "." parts are ignored).
        """
        target = Path(path)
        with self._lock:
            for model in self._iter_models():
                if Path(model.path) == target:
                    return model
        return None