mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
improvements to sdxl support in model manager
- Move SDXL-related models to models/sdxl.py - Create separate base type BaseModelType.StableDiffusionXLRefiner for the refiner models.
This commit is contained in:
parent
130249a2dd
commit
bf2b5b5cd4
@ -101,7 +101,7 @@ class ModelProbe(object):
|
|||||||
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
||||||
and prediction_type==SchedulerPredictionType.VPrediction),
|
and prediction_type==SchedulerPredictionType.VPrediction),
|
||||||
format = format,
|
format = format,
|
||||||
image_size = 1024 if (base_type==BaseModelType.StableDiffusionXL) else \
|
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
|
||||||
768 if (base_type==BaseModelType.StableDiffusion2 \
|
768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||||
and prediction_type==SchedulerPredictionType.VPrediction ) else \
|
and prediction_type==SchedulerPredictionType.VPrediction ) else \
|
||||||
512
|
512
|
||||||
@ -366,7 +366,9 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif unet_conf['cross_attention_dim'] == 1024:
|
elif unet_conf['cross_attention_dim'] == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
elif unet_conf['cross_attention_dim'] in {1280,2048}:
|
elif unet_conf['cross_attention_dim'] == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
|
elif unet_conf['cross_attention_dim'] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown base model for {self.folder_path}')
|
raise ValueError(f'Unknown base model for {self.folder_path}')
|
||||||
|
@ -3,7 +3,8 @@ from enum import Enum
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model, StableDiffusionXLModel
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
|
from .sdxl import StableDiffusionXLModel
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
from .controlnet import ControlNetModel # TODO:
|
from .controlnet import ControlNetModel # TODO:
|
||||||
@ -32,6 +33,14 @@ MODEL_CLASSES = {
|
|||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
|
ModelType.Main: StableDiffusionXLModel,
|
||||||
|
ModelType.Vae: VaeModel,
|
||||||
|
# will not work until support written
|
||||||
|
ModelType.Lora: LoRAModel,
|
||||||
|
ModelType.ControlNet: ControlNetModel,
|
||||||
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
},
|
||||||
#BaseModelType.Kandinsky2_1: {
|
#BaseModelType.Kandinsky2_1: {
|
||||||
# ModelType.Main: Kandinsky2_1Model,
|
# ModelType.Main: Kandinsky2_1Model,
|
||||||
# ModelType.MoVQ: MoVQModel,
|
# ModelType.MoVQ: MoVQModel,
|
||||||
|
@ -22,6 +22,7 @@ class BaseModelType(str, Enum):
|
|||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
StableDiffusionXL = "sdxl"
|
StableDiffusionXL = "sdxl"
|
||||||
|
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||||
#Kandinsky2_1 = "kandinsky-2.1"
|
#Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
|
114
invokeai/backend/model_management/models/sdxl.py
Normal file
114
invokeai/backend/model_management/models/sdxl.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
from .base import (
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
DiffusersModel,
|
||||||
|
read_checkpoint_meta,
|
||||||
|
classproperty,
|
||||||
|
)
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class StableDiffusionXLModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
|
class StableDiffusionXLModel(DiffusersModel):
|
||||||
|
|
||||||
|
# TODO: check that configs overwriten properly
|
||||||
|
class DiffusersConfig(ModelConfigBase):
|
||||||
|
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
class CheckpointConfig(ModelConfigBase):
|
||||||
|
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
config: str
|
||||||
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
|
||||||
|
assert model_type == ModelType.Main
|
||||||
|
super().__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusionXL,
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe_config(cls, path: str, **kwargs):
|
||||||
|
model_format = cls.detect_format(path)
|
||||||
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
|
if model_format == StableDiffusionXLModelFormat.Checkpoint:
|
||||||
|
if ckpt_config_path:
|
||||||
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
checkpoint = read_checkpoint_meta(path)
|
||||||
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
|
elif model_format == StableDiffusionXLModelFormat.Diffusers:
|
||||||
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
|
if os.path.exists(unet_config_path):
|
||||||
|
with open(unet_config_path, "r") as f:
|
||||||
|
unet_config = json.loads(f.read())
|
||||||
|
in_channels = unet_config['in_channels']
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
||||||
|
|
||||||
|
if in_channels == 9:
|
||||||
|
variant = ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
variant = ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
variant = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise Exception("Unkown stable diffusion 2.* model format")
|
||||||
|
|
||||||
|
if ckpt_config_path is None:
|
||||||
|
# TO DO: implement picking
|
||||||
|
pass
|
||||||
|
|
||||||
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
model_format=model_format,
|
||||||
|
|
||||||
|
config=ckpt_config_path,
|
||||||
|
variant=variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, model_path: str):
|
||||||
|
if os.path.isdir(model_path):
|
||||||
|
return StableDiffusionXLModelFormat.Diffusers
|
||||||
|
else:
|
||||||
|
return StableDiffusionXLModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
|
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
|
||||||
|
else:
|
||||||
|
return model_path
|
@ -5,14 +5,11 @@ from pydantic import Field
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
DiffusersModel,
|
DiffusersModel,
|
||||||
SchedulerPredictionType,
|
|
||||||
SilenceWarnings,
|
SilenceWarnings,
|
||||||
read_checkpoint_meta,
|
read_checkpoint_meta,
|
||||||
classproperty,
|
classproperty,
|
||||||
@ -222,105 +219,6 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
class StableDiffusionXLModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
class StableDiffusionXLModel(DiffusersModel):
|
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: str
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusionXL
|
|
||||||
assert model_type == ModelType.Main
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusionXL,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
|
||||||
if model_format == StableDiffusionXLModelFormat.Checkpoint:
|
|
||||||
if ckpt_config_path:
|
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
checkpoint = read_checkpoint_meta(path)
|
|
||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
|
|
||||||
elif model_format == StableDiffusionXLModelFormat.Diffusers:
|
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
|
||||||
if os.path.exists(unet_config_path):
|
|
||||||
with open(unet_config_path, "r") as f:
|
|
||||||
unet_config = json.loads(f.read())
|
|
||||||
in_channels = unet_config['in_channels']
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
variant = ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
|
||||||
|
|
||||||
if ckpt_config_path is None:
|
|
||||||
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusionXL, variant)
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
|
|
||||||
config=ckpt_config_path,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
if os.path.isdir(model_path):
|
|
||||||
return StableDiffusionXLModelFormat.Diffusers
|
|
||||||
else:
|
|
||||||
return StableDiffusionXLModelFormat.Checkpoint
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
|
||||||
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||||
ckpt_configs = {
|
ckpt_configs = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
@ -355,7 +253,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|||||||
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
|
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
|
||||||
def _convert_ckpt_and_cache(
|
def _convert_ckpt_and_cache(
|
||||||
version: BaseModelType,
|
version: BaseModelType,
|
||||||
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig, StableDiffusionXLModel.CheckpointConfig],
|
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
||||||
output_path: str,
|
output_path: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user