mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add draft SD3 probing; there is an issue with FromOriginalControlNetMixin in backend.util.hotfixes due to new diffusers
This commit is contained in:
parent
568a4844f7
commit
002f8242a1
@ -51,6 +51,7 @@ class BaseModelType(str, Enum):
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
StableDiffusion3 = "sd-3"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
@ -74,8 +75,10 @@ class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
|
@ -100,6 +100,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"StableDiffusion3Pipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
@ -298,10 +299,13 @@ class ModelProbe(object):
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
if base_type is BaseModelType.StableDiffusion3:
|
||||
config_file = "stable-diffusion/v3-inference.yaml"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"controlnet/cldm_v15.yaml"
|
||||
@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
|
||||
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
||||
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
||||
return MainModelDefaultSettings(width=512, height=512)
|
||||
elif model_base is BaseModelType.StableDiffusionXL:
|
||||
elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
|
||||
return MainModelDefaultSettings(width=1024, height=1024)
|
||||
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
||||
return None
|
||||
@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
key = "model.diffusion_model.input_blocks.0.0.weight"
|
||||
if key not in state_dict:
|
||||
return ModelVariantType.Normal
|
||||
in_channels = state_dict[key].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
|
||||
if key_name in state_dict:
|
||||
return BaseModelType.StableDiffusion3
|
||||
else:
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
@ -588,6 +598,10 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "model_index.json", "r") as file:
|
||||
index_conf = json.load(file)
|
||||
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
|
||||
return BaseModelType.StableDiffusion3
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
|
@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlNetMixin
|
||||
|
||||
# The following import is
|
||||
# generating import errors with diffusers 028.2
|
||||
# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
|
||||
# fails as well
|
||||
# from diffusers.loaders import FromOriginalControlNetMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.embeddings import (
|
||||
@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
@ -37,7 +37,7 @@ dependencies = [
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.27.2",
|
||||
"diffusers[torch]",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||
@ -47,11 +47,11 @@ dependencies = [
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.3",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch==2.2.2",
|
||||
"torch",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchsde==0.2.6",
|
||||
"torchvision==0.17.2",
|
||||
"transformers==4.41.1",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.0",
|
||||
|
Loading…
Reference in New Issue
Block a user