diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e3c99c5644..5c3f1c6e8f 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -51,6 +51,7 @@ class BaseModelType(str, Enum): StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" StableDiffusionXLRefiner = "sdxl-refiner" + StableDiffusion3 = "sd-3" # Kandinsky2_1 = "kandinsky-2.1" @@ -74,8 +75,10 @@ class SubModelType(str, Enum): UNet = "unet" TextEncoder = "text_encoder" TextEncoder2 = "text_encoder_2" + TextEncoder3 = "text_encoder_3" Tokenizer = "tokenizer" Tokenizer2 = "tokenizer_2" + Tokenizer3 = "tokenizer_3" VAE = "vae" VAEDecoder = "vae_decoder" VAEEncoder = "vae_encoder" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 8f33e4b49f..ca9d037234 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -100,6 +100,7 @@ class ModelProbe(object): "StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main, + "StableDiffusion3Pipeline": ModelType.Main, "AutoencoderKL": ModelType.VAE, "AutoencoderTiny": ModelType.VAE, "ControlNetModel": ModelType.ControlNet, @@ -298,10 +299,13 @@ class ModelProbe(object): return possible_conf.absolute() if model_type is ModelType.Main: - config_file = LEGACY_CONFIGS[base_type][variant_type] - if isinstance(config_file, dict): # need another tier for sd-2.x models - config_file = config_file[prediction_type] - config_file = f"stable-diffusion/{config_file}" + if base_type is BaseModelType.StableDiffusion3: + config_file = "stable-diffusion/v3-inference.yaml" + else: + config_file = LEGACY_CONFIGS[base_type][variant_type] + if isinstance(config_file, dict): # need another tier for sd-2.x models + config_file = config_file[prediction_type] + config_file = f"stable-diffusion/{config_file}" elif model_type is ModelType.ControlNet: config_file = ( "controlnet/cldm_v15.yaml" @@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: return MainModelDefaultSettings(width=512, height=512) - elif model_base is BaseModelType.StableDiffusionXL: + elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]: return MainModelDefaultSettings(width=1024, height=1024) # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. return None @@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase): if model_type != ModelType.Main: return ModelVariantType.Normal state_dict = self.checkpoint.get("state_dict") or self.checkpoint - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + key = "model.diffusion_model.input_blocks.0.0.weight" + if key not in state_dict: + return ModelVariantType.Normal + in_channels = state_dict[key].shape[1] if in_channels == 9: return ModelVariantType.Inpaint elif in_channels == 5: @@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusionXL elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: return BaseModelType.StableDiffusionXLRefiner + key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" + if key_name in state_dict: + return BaseModelType.StableDiffusion3 else: raise InvalidModelConfigException("Cannot determine base type") @@ -588,6 +598,10 @@ class FolderProbeBase(ProbeBase): class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: + with open(self.model_path / "model_index.json", "r") as file: + index_conf = json.load(file) + if index_conf.get("_class_name") == "StableDiffusion3Pipeline": + return BaseModelType.StableDiffusion3 with open(self.model_path / "unet" / "config.json", "r") as file: unet_conf = json.load(file) if unet_conf["cross_attention_dim"] == 768: diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 7e362fe958..9807754a33 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union import diffusers import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalControlNetMixin + +# The following import is +# generating import errors with diffusers 028.2 +# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this +# fails as well +# from diffusers.loaders import FromOriginalControlNetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module from diffusers.models.embeddings import ( @@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger(__name__) -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): +class ControlNetModel(ModelMixin, ConfigMixin): """ A ControlNet model. diff --git a/pyproject.toml b/pyproject.toml index fcc0aff60c..bf983a0c8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7", - "diffusers[torch]==0.27.2", + "diffusers[torch]", "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal() @@ -47,11 +47,11 @@ dependencies = [ "pytorch-lightning==2.1.3", "safetensors==0.4.3", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 - "torch==2.2.2", + "torch", "torchmetrics==0.11.4", "torchsde==0.2.6", - "torchvision==0.17.2", - "transformers==4.41.1", + "torchvision", + "transformers", # Core application dependencies, pinned for reproducible builds. "fastapi-events==0.11.0",