add draft SD3 probing; there is an issue with FromOriginalControlNetMixin in backend.util.hotfixes due to new diffusers

This commit is contained in:
Lincoln Stein 2024-06-12 22:44:34 -04:00
parent 568a4844f7
commit 002f8242a1
4 changed files with 34 additions and 12 deletions

View File

@ -51,6 +51,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl" StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner" StableDiffusionXLRefiner = "sdxl-refiner"
StableDiffusion3 = "sd-3"
# Kandinsky2_1 = "kandinsky-2.1" # Kandinsky2_1 = "kandinsky-2.1"
@ -74,8 +75,10 @@ class SubModelType(str, Enum):
UNet = "unet" UNet = "unet"
TextEncoder = "text_encoder" TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2" TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2" Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae" VAE = "vae"
VAEDecoder = "vae_decoder" VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder" VAEEncoder = "vae_encoder"

View File

@ -100,6 +100,7 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE, "AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE, "AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
@ -298,6 +299,9 @@ class ModelProbe(object):
return possible_conf.absolute() return possible_conf.absolute()
if model_type is ModelType.Main: if model_type is ModelType.Main:
if base_type is BaseModelType.StableDiffusion3:
config_file = "stable-diffusion/v3-inference.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type] config_file = config_file[prediction_type]
@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512) return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL: elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
return MainModelDefaultSettings(width=1024, height=1024) return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None return None
@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
if model_type != ModelType.Main: if model_type != ModelType.Main:
return ModelVariantType.Normal return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] key = "model.diffusion_model.input_blocks.0.0.weight"
if key not in state_dict:
return ModelVariantType.Normal
in_channels = state_dict[key].shape[1]
if in_channels == 9: if in_channels == 9:
return ModelVariantType.Inpaint return ModelVariantType.Inpaint
elif in_channels == 5: elif in_channels == 5:
@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner return BaseModelType.StableDiffusionXLRefiner
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
if key_name in state_dict:
return BaseModelType.StableDiffusion3
else: else:
raise InvalidModelConfigException("Cannot determine base type") raise InvalidModelConfigException("Cannot determine base type")
@ -588,6 +598,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
with open(self.model_path / "model_index.json", "r") as file:
index_conf = json.load(file)
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
return BaseModelType.StableDiffusion3
with open(self.model_path / "unet" / "config.json", "r") as file: with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file) unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768: if unet_conf["cross_attention_dim"] == 768:

View File

@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers import diffusers
import torch import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
# The following import is
# generating import errors with diffusers 028.2
# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
# fails as well
# from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import ( from diffusers.models.embeddings import (
@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__) logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): class ControlNetModel(ModelMixin, ConfigMixin):
""" """
A ControlNet model. A ControlNet model.

View File

@ -37,7 +37,7 @@ dependencies = [
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2", "compel==2.0.2",
"controlnet-aux==0.0.7", "controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2", "diffusers[torch]",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal() "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
@ -47,11 +47,11 @@ dependencies = [
"pytorch-lightning==2.1.3", "pytorch-lightning==2.1.3",
"safetensors==0.4.3", "safetensors==0.4.3",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2", "torch",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",
"torchsde==0.2.6", "torchsde==0.2.6",
"torchvision==0.17.2", "torchvision",
"transformers==4.41.1", "transformers",
# Core application dependencies, pinned for reproducible builds. # Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0", "fastapi-events==0.11.0",