mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: Update model config type names
This commit is contained in:
parent
4cbc802e36
commit
67d05d2066
@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
|
|||||||
#model_class: Type
|
#model_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class ControlNetModelConfig(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
@ -82,6 +82,6 @@ class ControlNetModel(ModelBase):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) != "diffusers":
|
||||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
|
|||||||
class LoRAModel(ModelBase):
|
class LoRAModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class LoraModelConfig(ModelConfigBase):
|
||||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
@ -22,12 +22,12 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointModelConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion1,
|
version=BaseModelType.StableDiffusion1,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
class StableDiffusion2Model(DiffusersModel):
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointModelConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion2,
|
version=BaseModelType.StableDiffusion2,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|||||||
# TODO: rework
|
# TODO: rework
|
||||||
def _convert_ckpt_and_cache(
|
def _convert_ckpt_and_cache(
|
||||||
version: BaseModelType,
|
version: BaseModelType,
|
||||||
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
|
||||||
output_path: str,
|
output_path: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -281,8 +281,8 @@ def _convert_ckpt_and_cache(
|
|||||||
prediction_type = SchedulerPredictionType.Epsilon
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
|
||||||
elif version == BaseModelType.StableDiffusion2:
|
elif version == BaseModelType.StableDiffusion2:
|
||||||
upcast_attention = config.upcast_attention
|
upcast_attention = model_config.upcast_attention
|
||||||
prediction_type = config.prediction_type
|
prediction_type = model_config.prediction_type
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown model provided: {version}")
|
raise Exception(f"Unknown model provided: {version}")
|
||||||
|
@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
|
|||||||
class TextualInversionModel(ModelBase):
|
class TextualInversionModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class TextualInversionModelConfig(ModelConfigBase):
|
||||||
format: None
|
format: None
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import safetensors
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -22,7 +23,7 @@ class VaeModel(ModelBase):
|
|||||||
#vae_class: Type
|
#vae_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class VAEModelConfig(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
Loading…
Reference in New Issue
Block a user