diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 71d5ba779c..e128792d70 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -116,16 +116,15 @@ class CompelInvocation(BaseInvocation): text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, + truncate_long_prompts=False, ) conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + log_tokenization_for_conjunction(conjunction, tokenizer) - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) + c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), @@ -231,7 +230,7 @@ class SDXLPromptInvocationBase: text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, ) @@ -240,8 +239,7 @@ class SDXLPromptInvocationBase: if context.services.configuration.log_tokenization: # TODO: better logging for and syntax - for prompt_obj in conjunction.prompts: - log_tokenization_for_prompt_object(prompt_obj, tokenizer) + log_tokenization_for_conjunction(conjunction, tokenizer) # TODO: ask for optimizations? to not run text_encoder twice c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py index 54885769ad..7138f2e123 100644 --- a/invokeai/backend/model_management/seamless.py +++ b/invokeai/backend/model_management/seamless.py @@ -71,7 +71,6 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe """ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - print(f"applied - {m_name}") m.asymmetric_padding_mode = {} m.asymmetric_padding = {} m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 3d64d8a42c..983d0b7601 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import diffusers import torch -from torch import nn - from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalControlnetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module from diffusers.models.embeddings import ( TextImageProjection, TextImageTimeEmbedding, @@ -14,16 +14,9 @@ from diffusers.models.embeddings import ( Timesteps, ) from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - get_down_block, -) +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block from diffusers.models.unet_2d_condition import UNet2DConditionModel - -import diffusers -from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module +from torch import nn from invokeai.backend.util.logging import InvokeAILogger @@ -45,7 +38,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", \ + "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): @@ -147,7 +141,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # when this library was created... + # The incorrect naming was only discovered much ... + # later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim @@ -155,17 +151,20 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + f"Must provide the same number of `block_out_channels` as `down_block_types`. \ + `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + f"Must provide the same number of `only_cross_attention` as `down_block_types`. \ + `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + f"Must provide the same number of `num_attention_heads` as `down_block_types`. \ + `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if isinstance(transformer_layers_per_block, int): @@ -202,7 +201,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # they are set to `cross_attention_dim` here as this is exactly the required dimension ... + # for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, @@ -250,8 +250,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. + # To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension... + # for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim @@ -673,12 +675,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \ + requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \ + requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) @@ -776,3 +780,49 @@ def new_LoRACompatibleConv_forward(self, x): diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward + +try: + import xformers + + xformers_available = True +except Exception: + xformers_available = False + + +if xformers_available: + # TODO: remove when fixed in diffusers + _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention + + def new_memory_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias=None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op=None, + ): + # diffusers not align shape to 8, which is required by xformers + if attn_bias is not None and type(attn_bias) is torch.Tensor: + orig_size = attn_bias.shape[-1] + new_size = ((orig_size + 7) // 8) * 8 + aligned_attn_bias = torch.zeros( + (attn_bias.shape[0], attn_bias.shape[1], new_size), + device=attn_bias.device, + dtype=attn_bias.dtype, + ) + aligned_attn_bias[:, :, :orig_size] = attn_bias + attn_bias = aligned_attn_bias[:, :, :orig_size] + + return _xformers_memory_efficient_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=p, + scale=scale, + op=op, + ) + + xformers.ops.memory_efficient_attention = new_memory_efficient_attention diff --git a/pyproject.toml b/pyproject.toml index 9aef66a35f..129538264d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "albumentations", "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", - "compel~=2.0.0", + "compel~=2.0.2", "controlnet-aux>=0.0.6", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets",