mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add draft SD3 probing; there is an issue with FromOriginalControlNetMixin in backend.util.hotfixes due to new diffusers
This commit is contained in:
parent
568a4844f7
commit
002f8242a1
@ -51,6 +51,7 @@ class BaseModelType(str, Enum):
|
|||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
StableDiffusionXL = "sdxl"
|
StableDiffusionXL = "sdxl"
|
||||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||||
|
StableDiffusion3 = "sd-3"
|
||||||
# Kandinsky2_1 = "kandinsky-2.1"
|
# Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
|
|
||||||
@ -74,8 +75,10 @@ class SubModelType(str, Enum):
|
|||||||
UNet = "unet"
|
UNet = "unet"
|
||||||
TextEncoder = "text_encoder"
|
TextEncoder = "text_encoder"
|
||||||
TextEncoder2 = "text_encoder_2"
|
TextEncoder2 = "text_encoder_2"
|
||||||
|
TextEncoder3 = "text_encoder_3"
|
||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
Tokenizer2 = "tokenizer_2"
|
Tokenizer2 = "tokenizer_2"
|
||||||
|
Tokenizer3 = "tokenizer_3"
|
||||||
VAE = "vae"
|
VAE = "vae"
|
||||||
VAEDecoder = "vae_decoder"
|
VAEDecoder = "vae_decoder"
|
||||||
VAEEncoder = "vae_encoder"
|
VAEEncoder = "vae_encoder"
|
||||||
|
@ -100,6 +100,7 @@ class ModelProbe(object):
|
|||||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||||
|
"StableDiffusion3Pipeline": ModelType.Main,
|
||||||
"AutoencoderKL": ModelType.VAE,
|
"AutoencoderKL": ModelType.VAE,
|
||||||
"AutoencoderTiny": ModelType.VAE,
|
"AutoencoderTiny": ModelType.VAE,
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
@ -298,6 +299,9 @@ class ModelProbe(object):
|
|||||||
return possible_conf.absolute()
|
return possible_conf.absolute()
|
||||||
|
|
||||||
if model_type is ModelType.Main:
|
if model_type is ModelType.Main:
|
||||||
|
if base_type is BaseModelType.StableDiffusion3:
|
||||||
|
config_file = "stable-diffusion/v3-inference.yaml"
|
||||||
|
else:
|
||||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||||
config_file = config_file[prediction_type]
|
config_file = config_file[prediction_type]
|
||||||
@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
|
|||||||
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
||||||
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
||||||
return MainModelDefaultSettings(width=512, height=512)
|
return MainModelDefaultSettings(width=512, height=512)
|
||||||
elif model_base is BaseModelType.StableDiffusionXL:
|
elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
|
||||||
return MainModelDefaultSettings(width=1024, height=1024)
|
return MainModelDefaultSettings(width=1024, height=1024)
|
||||||
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
||||||
return None
|
return None
|
||||||
@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
|
|||||||
if model_type != ModelType.Main:
|
if model_type != ModelType.Main:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
key = "model.diffusion_model.input_blocks.0.0.weight"
|
||||||
|
if key not in state_dict:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
in_channels = state_dict[key].shape[1]
|
||||||
if in_channels == 9:
|
if in_channels == 9:
|
||||||
return ModelVariantType.Inpaint
|
return ModelVariantType.Inpaint
|
||||||
elif in_channels == 5:
|
elif in_channels == 5:
|
||||||
@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||||
return BaseModelType.StableDiffusionXLRefiner
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
|
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
|
||||||
|
if key_name in state_dict:
|
||||||
|
return BaseModelType.StableDiffusion3
|
||||||
else:
|
else:
|
||||||
raise InvalidModelConfigException("Cannot determine base type")
|
raise InvalidModelConfigException("Cannot determine base type")
|
||||||
|
|
||||||
@ -588,6 +598,10 @@ class FolderProbeBase(ProbeBase):
|
|||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
with open(self.model_path / "model_index.json", "r") as file:
|
||||||
|
index_conf = json.load(file)
|
||||||
|
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
|
||||||
|
return BaseModelType.StableDiffusion3
|
||||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||||
unet_conf = json.load(file)
|
unet_conf = json.load(file)
|
||||||
if unet_conf["cross_attention_dim"] == 768:
|
if unet_conf["cross_attention_dim"] == 768:
|
||||||
|
@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import diffusers
|
import diffusers
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
from diffusers.loaders import FromOriginalControlNetMixin
|
|
||||||
|
# The following import is
|
||||||
|
# generating import errors with diffusers 028.2
|
||||||
|
# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
|
||||||
|
# fails as well
|
||||||
|
# from diffusers.loaders import FromOriginalControlNetMixin
|
||||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
from diffusers.models.embeddings import (
|
from diffusers.models.embeddings import (
|
||||||
@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
|||||||
logger = InvokeAILogger.get_logger(__name__)
|
logger = InvokeAILogger.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
A ControlNet model.
|
A ControlNet model.
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ dependencies = [
|
|||||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==2.0.2",
|
"compel==2.0.2",
|
||||||
"controlnet-aux==0.0.7",
|
"controlnet-aux==0.0.7",
|
||||||
"diffusers[torch]==0.27.2",
|
"diffusers[torch]",
|
||||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||||
@ -47,11 +47,11 @@ dependencies = [
|
|||||||
"pytorch-lightning==2.1.3",
|
"pytorch-lightning==2.1.3",
|
||||||
"safetensors==0.4.3",
|
"safetensors==0.4.3",
|
||||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||||
"torch==2.2.2",
|
"torch",
|
||||||
"torchmetrics==0.11.4",
|
"torchmetrics==0.11.4",
|
||||||
"torchsde==0.2.6",
|
"torchsde==0.2.6",
|
||||||
"torchvision==0.17.2",
|
"torchvision",
|
||||||
"transformers==4.41.1",
|
"transformers",
|
||||||
|
|
||||||
# Core application dependencies, pinned for reproducible builds.
|
# Core application dependencies, pinned for reproducible builds.
|
||||||
"fastapi-events==0.11.0",
|
"fastapi-events==0.11.0",
|
||||||
|
Loading…
Reference in New Issue
Block a user