mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: Update model config type names
This commit is contained in:
parent
2d889e133d
commit
9838dda1b7
@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
class ControlNetModelConfig(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
@ -82,6 +82,6 @@ class ControlNetModel(ModelBase):
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
||||
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||
else:
|
||||
return model_path
|
||||
|
@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
class LoraModelConfig(ModelConfigBase):
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
@ -22,12 +22,12 @@ from omegaconf import OmegaConf
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
) -> str:
|
||||
assert model_path == config.path
|
||||
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
if isinstance(config, cls.CheckpointModelConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
) -> str:
|
||||
assert model_path == config.path
|
||||
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
if isinstance(config, cls.CheckpointModelConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
model_config=config,
|
||||
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
# TODO: rework
|
||||
def _convert_ckpt_and_cache(
|
||||
version: BaseModelType,
|
||||
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
||||
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
|
||||
output_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
@ -281,8 +281,8 @@ def _convert_ckpt_and_cache(
|
||||
prediction_type = SchedulerPredictionType.Epsilon
|
||||
|
||||
elif version == BaseModelType.StableDiffusion2:
|
||||
upcast_attention = config.upcast_attention
|
||||
prediction_type = config.prediction_type
|
||||
upcast_attention = model_config.upcast_attention
|
||||
prediction_type = model_config.prediction_type
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown model provided: {version}")
|
||||
|
@ -15,7 +15,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
class TextualInversionModelConfig(ModelConfigBase):
|
||||
format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
@ -22,7 +23,7 @@ class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
class VAEModelConfig(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
Loading…
Reference in New Issue
Block a user