mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackify
This commit is contained in:
@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
UnknownModelException,
|
||||
DuplicateModelException
|
||||
DuplicateModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any
|
||||
@ -51,10 +51,10 @@ class ModelManagerServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model identified by key.
|
||||
|
||||
@ -71,8 +71,8 @@ class ModelManagerServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
key: str,
|
||||
self,
|
||||
key: str,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@ -85,10 +85,11 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
def list_models(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return a list of ModelConfigBases that match the base, type and name criteria.
|
||||
@ -104,13 +105,11 @@ class ModelManagerServiceBase(ABC):
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
model_configs = self.list_models(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type
|
||||
)
|
||||
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException("More than one model share the same name and type: {base_model}/{model_type}/{model_name}")
|
||||
raise DuplicateModelException(
|
||||
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
|
||||
)
|
||||
if len(model_configs) == 0:
|
||||
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
|
||||
return model_configs[0]
|
||||
@ -123,22 +122,19 @@ class ModelManagerServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
probe_overrides: Optional[Dict[str, Any]] = None,
|
||||
wait: bool = False
|
||||
self, model_path: Path, probe_overrides: Optional[Dict[str, Any]] = None, wait: bool = False
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using its path, with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists.
|
||||
assertion error if the name already exists.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
@ -151,11 +147,7 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
key: str,
|
||||
delete_files: bool = False
|
||||
):
|
||||
def del_model(self, key: str, delete_files: bool = False):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files
|
||||
is true, then the underlying file or directory will be
|
||||
@ -164,9 +156,9 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Rename the indicated model.
|
||||
@ -182,14 +174,14 @@ class ModelManagerServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
This will delete the cached version if there is any and delete the original
|
||||
This will delete the cached version if there is any and delete the original
|
||||
checkpoint file if it is in the models directory.
|
||||
:param key: Unique key for the model to convert.
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
@ -201,10 +193,10 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_model (
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
|
||||
|
||||
@ -270,7 +262,7 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
# implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@ -289,18 +281,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self._loader = ModelLoader(config)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
|
||||
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
|
||||
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
@ -313,8 +305,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
key: str,
|
||||
self,
|
||||
key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model key, returns True if it is a valid
|
||||
@ -331,10 +323,11 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
# def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models()
|
||||
# def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class
|
||||
|
||||
def list_models(self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
def list_models(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return a ModelConfigBase object for each model in the database.
|
||||
@ -347,14 +340,11 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_attributes: Optional[dict] = None,
|
||||
wait: bool = False
|
||||
self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using its path, with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists.
|
||||
assertion error if the name already exists.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {model_path}")
|
||||
return self._loader.installer.install(
|
||||
@ -363,16 +353,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using its path, with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists.
|
||||
assertion error if the name already exists.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {source}")
|
||||
variant = 'fp16' if self._loader.precision == 'float16' else None
|
||||
variant = "fp16" if self._loader.precision == "float16" else None
|
||||
return self._loader.installer.install(
|
||||
source,
|
||||
probe_override=model_attributes,
|
||||
@ -380,9 +370,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
@ -398,9 +388,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self._loader.store.update_model(key, new_config)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
key: str,
|
||||
delete_files: bool = False,
|
||||
self,
|
||||
key: str,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
@ -418,9 +408,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
path.unlink()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -436,7 +426,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
model_info = self.model_info(key)
|
||||
self.logger.debug(f"convert model {model_info.name}")
|
||||
self.logger.warning('This is not implemented yet')
|
||||
self.logger.warning("This is not implemented yet")
|
||||
return self._loader.convert_model(key, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
@ -478,17 +468,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
@property
|
||||
def logger(self):
|
||||
return self._loader.logger
|
||||
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -500,7 +490,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
self.logger.error('ModelMerger needs to be rewritten.')
|
||||
self.logger.error("ModelMerger needs to be rewritten.")
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
|
@ -172,11 +172,13 @@ class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
|
@ -344,14 +344,12 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
|
||||
version = match.group(1)
|
||||
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
||||
metadata.thumbnail_url = metadata.thumbnail_url \
|
||||
or resp["images"][0]["url"]
|
||||
metadata.description = metadata.description \
|
||||
or (
|
||||
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
|
||||
if resp["trainedWords"]
|
||||
else resp["description"]
|
||||
)
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
|
||||
if resp["trainedWords"]
|
||||
else resp["description"]
|
||||
)
|
||||
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
|
||||
|
||||
# a Civitai model page
|
||||
@ -364,10 +362,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
metadata.author = metadata.author or resp["creator"]["username"]
|
||||
metadata.tags = metadata.tags or resp["tags"]
|
||||
metadata.thumbnail_url = metadata.thumbnail_url \
|
||||
or resp["modelVersions"][0]["images"][0]["url"]
|
||||
metadata.license = metadata.license \
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["modelVersions"][0]["images"][0]["url"]
|
||||
metadata.license = (
|
||||
metadata.license
|
||||
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
||||
)
|
||||
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
|
||||
self._logger.warn(excp)
|
||||
|
||||
|
@ -316,9 +316,8 @@ class ModelInstall(ModelInstallBase):
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
def register_path(self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
@ -354,13 +353,13 @@ class ModelInstall(ModelInstallBase):
|
||||
return key
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
|
||||
|
||||
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -375,14 +374,11 @@ class ModelInstall(ModelInstallBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def _probe_model(self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str,Any]] = None
|
||||
) -> ModelProbeInfo:
|
||||
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
if overrides: # used to override probe fields
|
||||
for key, value in overrides.items():
|
||||
setattr(info, key, value) # may generate a pydantic validation error
|
||||
setattr(info, key, value) # may generate a pydantic validation error
|
||||
return info
|
||||
|
||||
def unregister(self, key: str): # noqa D102
|
||||
@ -394,12 +390,12 @@ class ModelInstall(ModelInstallBase):
|
||||
self.unregister(key)
|
||||
|
||||
def install(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
queue = self._download_queue
|
||||
|
||||
|
@ -75,13 +75,10 @@ class ModelLoaderBase(ABC):
|
||||
"""Return the current logger."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(
|
||||
self,
|
||||
cache_stats: CacheStats
|
||||
):
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Replace cache statistics."""
|
||||
|
||||
pass
|
||||
|
||||
@property
|
||||
@ -98,6 +95,7 @@ class ModelLoaderBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoader(ModelLoaderBase):
|
||||
"""Implementation of ModelLoaderBase."""
|
||||
|
||||
@ -228,10 +226,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
_cache=self._cache,
|
||||
)
|
||||
|
||||
def collect_cache_stats(
|
||||
self,
|
||||
cache_stats: CacheStats
|
||||
):
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
self._cache.stats = cache_stats
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
@ -275,7 +270,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk.")
|
||||
for model_config in self._store.all_models():
|
||||
path = self._resolve_model_path(model_config.path)
|
||||
|
@ -128,7 +128,7 @@ class ModelProbe(ModelProbeBase):
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
|
||||
|
||||
if not probe_class:
|
||||
return None
|
||||
|
||||
@ -138,7 +138,7 @@ class ModelProbe(ModelProbeBase):
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
format = probe.get_format()
|
||||
|
||||
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
|
@ -11,6 +11,7 @@ from ..config import ModelConfigBase, BaseModelType, ModelType
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.1.1"
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Raised on an attempt to add a model with the same key twice."""
|
||||
|
||||
@ -122,4 +123,3 @@ class ModelConfigStore(ABC):
|
||||
Return all the model configs in the database.
|
||||
"""
|
||||
return self.search_by_name()
|
||||
|
||||
|
@ -89,8 +89,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
assert self.version == CONFIG_FILE_VERSION, \
|
||||
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
assert (
|
||||
self.version == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Create sqlite3 tables."""
|
||||
@ -182,9 +183,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
("version",CONFIG_FILE_VERSION),
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
@ -256,7 +256,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
def _update_tags(self, key: str, tags: List[str]) -> None:
|
||||
"""Update tags for model with key."""
|
||||
# remove previous tags from this model
|
||||
|
@ -78,8 +78,9 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
if not self._filename.exists():
|
||||
self._initialize_yaml()
|
||||
self._config = OmegaConf.load(self._filename)
|
||||
assert self.version == CONFIG_FILE_VERSION, \
|
||||
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
assert (
|
||||
self.version == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _initialize_yaml(self):
|
||||
try:
|
||||
@ -104,7 +105,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return version of this config file/database."""
|
||||
return self._config["__metadata__"].get('version')
|
||||
return self._config["__metadata__"].get("version")
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user