mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added ability to force config class returned by make_config()
This commit is contained in:
@ -22,7 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
import pydantic
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
from typing import Optional, Literal, List, Union
|
||||
from typing import Optional, Literal, List, Union, Type
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from omegaconf.listconfig import ListConfig # to support the yaml backend
|
||||
|
||||
@ -238,7 +238,9 @@ class ModelConfigFactory(object):
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls, model_data: Union[dict, ModelConfigBase]
|
||||
cls,
|
||||
model_data: Union[dict, ModelConfigBase],
|
||||
dest_class: Optional[Type] = None,
|
||||
) -> Union[
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
@ -247,14 +249,22 @@ class ModelConfigFactory(object):
|
||||
ONNXSD1Config,
|
||||
ONNXSD2Config,
|
||||
]:
|
||||
"""Return the appropriate config object from raw dict values."""
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
||||
:param model_data: A raw dict corresponding the obect fields to be
|
||||
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||
object, which will be passed through unchanged.
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
return model_data
|
||||
try:
|
||||
model_format = model_data.get("model_format")
|
||||
model_type = model_data.get("model_type")
|
||||
model_base = model_data.get("base_model")
|
||||
class_to_return = cls._class_map[model_format][model_type]
|
||||
class_to_return = dest_class or cls._class_map[model_format][model_type]
|
||||
if isinstance(class_to_return, dict): # additional level allowed
|
||||
class_to_return = class_to_return[model_base]
|
||||
return class_to_return.parse_obj(model_data)
|
||||
|
@ -4,7 +4,7 @@ Implementation of ModelConfigStore using a SQLite3 database
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
|
||||
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreSQL
|
||||
store = ModelConfigStoreYAML("./configs/models.yaml")
|
||||
config = dict(
|
||||
path='/tmp/pokemon.bin',
|
||||
@ -13,9 +13,10 @@ Typical usage:
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding
|
||||
# adding - the key becomes the model's "id" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
@ -29,9 +30,14 @@ Typical usage:
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.id == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_type(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
@ -13,9 +13,10 @@ Typical usage:
|
||||
model_type='embedding',
|
||||
model_format='embedding_file',
|
||||
author='Anonymous',
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding
|
||||
# adding - the key becomes the model's "id" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
@ -29,9 +30,14 @@ Typical usage:
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.id == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_type(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
Reference in New Issue
Block a user