sync pydantic and sql field names; merge routes

This commit is contained in:
Lincoln Stein
2023-11-06 18:08:57 -05:00
parent 55f8865524
commit ce22c0fbaa
6 changed files with 82 additions and 45 deletions

View File

@ -7,8 +7,8 @@ Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base_model='sd-1',
model_type='main',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
path: str
name: str
base_model: BaseModelType
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
vae: Optional[str] = Field(None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(MainConfig):
@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
prediction_type: SchedulerPredictionType
upcast_attention: bool
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
@ -305,7 +314,7 @@ class ModelConfigFactory(object):
try:
format = model_data.get("format")
type = model_data.get("type")
model_base = model_data.get("base_model")
model_base = model_data.get("base")
class_to_return = dest_class or cls._class_map[format][type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]

View File

@ -50,7 +50,7 @@ class Migrate:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base_model"] = BaseModelType(base_type)
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash