mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
sync pydantic and sql field names; merge routes
This commit is contained in:
@ -7,8 +7,8 @@ Typical usage:
|
||||
from invokeai.backend.model_manager import ModelConfigFactory
|
||||
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||
name='foo',
|
||||
base_model='sd-1',
|
||||
model_type='main',
|
||||
base='sd-1',
|
||||
type='main',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
variant='normal',
|
||||
format='checkpoint'
|
||||
@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
|
||||
|
||||
path: str
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
format: ModelFormat
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase):
|
||||
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
# Note that we do not need prediction_type or upcast_attention here
|
||||
# because they are provided in the checkpoint's own config file.
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD2Config(MainConfig):
|
||||
@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig):
|
||||
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
# No yaml config file for ONNX, so these are part of config
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||
upcast_attention: bool = True
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
@ -305,7 +314,7 @@ class ModelConfigFactory(object):
|
||||
try:
|
||||
format = model_data.get("format")
|
||||
type = model_data.get("type")
|
||||
model_base = model_data.get("base_model")
|
||||
model_base = model_data.get("base")
|
||||
class_to_return = dest_class or cls._class_map[format][type]
|
||||
if isinstance(class_to_return, dict): # additional level allowed
|
||||
class_to_return = class_to_return[model_base]
|
||||
|
@ -50,7 +50,7 @@ class Migrate:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base_model"] = BaseModelType(base_type)
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
|
Reference in New Issue
Block a user