added ability to force config class returned by make_config()

This commit is contained in:
Lincoln Stein
2023-08-13 19:08:50 -04:00
parent 155d9fcb13
commit c56fb38855
3 changed files with 29 additions and 7 deletions

View File

@ -22,7 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
import pydantic
from enum import Enum
from pydantic import BaseModel, Field, Extra
from typing import Optional, Literal, List, Union
from typing import Optional, Literal, List, Union, Type
from pydantic.error_wrappers import ValidationError
from omegaconf.listconfig import ListConfig # to support the yaml backend
@ -238,7 +238,9 @@ class ModelConfigFactory(object):
@classmethod
def make_config(
cls, model_data: Union[dict, ModelConfigBase]
cls,
model_data: Union[dict, ModelConfigBase],
dest_class: Optional[Type] = None,
) -> Union[
MainCheckpointConfig,
MainDiffusersConfig,
@ -247,14 +249,22 @@ class ModelConfigFactory(object):
ONNXSD1Config,
ONNXSD2Config,
]:
"""Return the appropriate config object from raw dict values."""
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
return model_data
try:
model_format = model_data.get("model_format")
model_type = model_data.get("model_type")
model_base = model_data.get("base_model")
class_to_return = cls._class_map[model_format][model_type]
class_to_return = dest_class or cls._class_map[model_format][model_type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
return class_to_return.parse_obj(model_data)

View File

@ -4,7 +4,7 @@ Implementation of ModelConfigStore using a SQLite3 database
Typical usage:
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreYAML
from invokeai.backend.model_management2.storage.yaml import ModelConfigStoreSQL
store = ModelConfigStoreYAML("./configs/models.yaml")
config = dict(
path='/tmp/pokemon.bin',
@ -13,9 +13,10 @@ Typical usage:
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
tags=['sfw','cartoon']
)
# adding
# adding - the key becomes the model's "id" field
store.add_model('key1', config)
# updating
@ -29,9 +30,14 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.id == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_type(base_model='sd-2', model_type='main')
"""
import threading

View File

@ -13,9 +13,10 @@ Typical usage:
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
tags=['sfw','cartoon']
)
# adding
# adding - the key becomes the model's "id" field
store.add_model('key1', config)
# updating
@ -29,9 +30,14 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.id == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_type(base_model='sd-2', model_type='main')
"""
import threading