diff --git a/invokeai/app/invocations/control_adapter.py b/invokeai/app/invocations/control_adapter.py index 36053e3b1c..93430654a0 100644 --- a/invokeai/app/invocations/control_adapter.py +++ b/invokeai/app/invocations/control_adapter.py @@ -29,6 +29,7 @@ CONTROLNET_RESIZE_VALUES = Literal[ "just_resize_simple", ] + class ControlNetModelField(BaseModel): """ControlNet model field""" @@ -68,6 +69,7 @@ class ControlField(BaseModel): raise ValueError("Control weights must be within -1 to 2 range") return v + @invocation_output("control_output") class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" @@ -78,7 +80,6 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) - @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -119,19 +120,21 @@ class ControlNetInvocation(BaseInvocation): ), ) + IP_ADAPTER_MODELS = Literal[ "models_ip_adapter/models/ip-adapter_sd15.bin", "models_ip_adapter/models/ip-adapter-plus_sd15.bin", "models_ip_adapter/models/ip-adapter-plus-face_sd15.bin", - "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin" + "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin", ] IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[ "models_ip_adapter/models/image_encoder/", "./models_ip_adapter/models/image_encoder/", - "models_ip_adapter/sdxl_models/image_encoder/" + "models_ip_adapter/sdxl_models/image_encoder/", ] + @invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes""" @@ -140,14 +143,15 @@ class IPAdapterInvocation(BaseInvocation): # Inputs image: ImageField = InputField(description="The control image") - #control_model: ControlNetModelField = InputField( + # control_model: ControlNetModelField = InputField( # default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct - #) - ip_adapter_model: IP_ADAPTER_MODELS = InputField(default="./models_ip_adapter/models/ip-adapter_sd15.bin", - description="The IP-Adapter model") + # ) + ip_adapter_model: IP_ADAPTER_MODELS = InputField( + default="./models_ip_adapter/models/ip-adapter_sd15.bin", description="The IP-Adapter model" + ) image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField( - default="./models_ip_adapter/models/image_encoder/", - description="The image encoder model") + default="./models_ip_adapter/models/image_encoder/", description="The image encoder model" + ) control_weight: Union[float, List[float]] = InputField( default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float ) @@ -172,9 +176,9 @@ class IPAdapterInvocation(BaseInvocation): image_encoder_model=self.image_encoder_model, control_weight=self.control_weight, # rest are currently ignored - #begin_step_percent=self.begin_step_percent, - #end_step_percent=self.end_step_percent, - #control_mode=self.control_mode, - #resize_mode=self.resize_mode, + # begin_step_percent=self.begin_step_percent, + # end_step_percent=self.end_step_percent, + # control_mode=self.control_mode, + # resize_mode=self.resize_mode, ), ) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 8edb69bfcf..cca05c5700 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -1,7 +1,7 @@ # Invocations for ControlNet image preprocessors # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux from builtins import bool, float -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Optional import cv2 import numpy as np @@ -27,17 +27,7 @@ from PIL import Image from invokeai.app.invocations.primitives import ImageField, ImageOutput from ..models.image import ImageCategory, ResourceOrigin -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - FieldDescriptions, - Input, - InputField, - InvocationContext, - OutputField, - UIType, - invocation, -) +from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation @invocation( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 37e4c5955b..2252dcee8f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -65,7 +65,6 @@ from .control_adapter import ControlField from .model import ModelInfo, UNetField, VaeField - DEFAULT_PRECISION = choose_precision(choose_torch_device()) SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] @@ -387,7 +386,7 @@ class DenoiseLatentsInvocation(BaseInvocation): resize_mode=control_info.resize_mode, ) control_item = ControlNetData( - model=control_model, # model object + model=control_model, # model object image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, @@ -404,7 +403,7 @@ class DenoiseLatentsInvocation(BaseInvocation): input_image = context.services.images.get_pil_image(control_image_field.image_name) control_item = IPAdapterData( ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object) - image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj) + image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj) image=input_image, weight=control_info.control_weight, ) @@ -564,8 +563,8 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data=conditioning_data, control_data=controlnet_data, # list[ControlNetData], ip_adapter_data=ip_adapter_data, # list[IPAdapterData], -# ip_adapter_image=unwrapped_ip_adapter_image, -# ip_adapter_strength=self.ip_adapter_strength, + # ip_adapter_image=unwrapped_ip_adapter_image, + # ip_adapter_strength=self.ip_adapter_strength, callback=step_callback, ) diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py index de9b367b7d..c3fd1b8bb1 100644 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ b/invokeai/backend/ip_adapter/attention_processor.py @@ -12,6 +12,7 @@ class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ + def __init__( self, hidden_size=None, @@ -140,7 +141,10 @@ class IPAttnProcessor(nn.Module): encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) # split hidden states - encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, : self.text_context_len, :], + encoder_hidden_states[:, self.text_context_len :, :], + ) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -186,6 +190,7 @@ class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ + def __init__( self, hidden_size=None, @@ -338,7 +343,10 @@ class IPAttnProcessor2_0(torch.nn.Module): encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) # split hidden states - encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, : self.text_context_len, :], + encoder_hidden_states[:, self.text_context_len :, :], + ) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index ddec16eebc..81bc6db847 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -22,6 +22,7 @@ from .resampler import Resampler class ImageProjModel(torch.nn.Module): """Projection Model""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() @@ -32,15 +33,15 @@ class ImageProjModel(torch.nn.Module): def forward(self, image_embeds): embeds = image_embeds - clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class IPAdapter: - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): - self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt @@ -54,7 +55,9 @@ class IPAdapter: self.set_ip_adapter() # load image encoder - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16) + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) self.clip_image_processor = CLIPImageProcessor() # image proj model self.image_proj_model = self.init_proj() @@ -88,8 +91,9 @@ class IPAdapter: attn_procs[name] = AttnProcessor() else: print("swapping in IPAttnProcessor for", name) - attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, - scale=1.0).to(self.device, dtype=torch.float16) + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) print("Modified UNet Attn Processors count:", len(unet.attn_processors)) print(unet.attn_processors.keys()) @@ -155,7 +159,12 @@ class IPAdapter: with torch.inference_mode(): prompt_embeds = self.pipe._encode_prompt( - prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) @@ -212,8 +221,17 @@ class IPAdapterXL(IPAdapter): uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt( - prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) @@ -243,7 +261,7 @@ class IPAdapterPlus(IPAdapter): num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4 + ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model @@ -255,6 +273,8 @@ class IPAdapterPlus(IPAdapter): clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py index 327ef7c140..38c4b06dcf 100644 --- a/invokeai/backend/ip_adapter/resampler.py +++ b/invokeai/backend/ip_adapter/resampler.py @@ -20,7 +20,7 @@ def FeedForward(dim, mult=4): def reshape_tensor(x, heads): bs, length, width = x.shape - #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) @@ -44,7 +44,6 @@ class PerceiverAttention(nn.Module): self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, latents): """ Args: @@ -68,7 +67,7 @@ class PerceiverAttention(nn.Module): # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v @@ -110,7 +109,6 @@ class Resampler(nn.Module): ) def forward(self, x): - latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) diff --git a/invokeai/backend/ip_adapter/utils.py b/invokeai/backend/ip_adapter/utils.py index e120a9e2b4..049d8163c2 100644 --- a/invokeai/backend/ip_adapter/utils.py +++ b/invokeai/backend/ip_adapter/utils.py @@ -15,7 +15,6 @@ from diffusers.models import ControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") @@ -150,9 +149,7 @@ def generate( control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ - control_guidance_end - ] + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end] # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -192,9 +189,7 @@ def generate( guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None prompt_embeds = self._encode_prompt( prompt, device, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index bdc6a9193d..891a9b9fb0 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -179,6 +179,7 @@ class IPAdapterData: # weight: Union[float, List[float]] = Field(default=1.0) weight: float = Field(default=1.0) + @dataclass class ConditioningData: unconditioned_embeddings: BasicConditioningInfo @@ -442,7 +443,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ip_adapter_data: List[IPAdapterData] = None, callback: Callable[[PipelineIntermediateState], None] = None, ): - self._adjust_memory_efficient_attention(latents) if additional_guidance is None: additional_guidance = [] @@ -469,30 +469,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # if "sdxl" in ip_adapter_info.ip_adapter_model: print("using IPAdapterXL") - ip_adapter = IPAdapterXL(self, - ip_adapter_info.image_encoder_model, - ip_adapter_info.ip_adapter_model, - self.unet.device) + ip_adapter = IPAdapterXL( + self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device + ) elif "plus" in ip_adapter_info.ip_adapter_model: print("using IPAdapterPlus") - ip_adapter = IPAdapterPlus(self, # IPAdapterPlus first arg is StableDiffusionPipeline - ip_adapter_info.image_encoder_model, - ip_adapter_info.ip_adapter_model, - self.unet.device, - num_tokens=16) + ip_adapter = IPAdapterPlus( + self, # IPAdapterPlus first arg is StableDiffusionPipeline + ip_adapter_info.image_encoder_model, + ip_adapter_info.ip_adapter_model, + self.unet.device, + num_tokens=16, + ) else: print("using IPAdapter") - ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline - ip_adapter_info.image_encoder_model, - ip_adapter_info.ip_adapter_model, - self.unet.device) + ip_adapter = IPAdapter( + self, # IPAdapter first arg is StableDiffusionPipeline + ip_adapter_info.image_encoder_model, + ip_adapter_info.ip_adapter_model, + self.unet.device, + ) # IP-Adapter ==> add additional cross-attention layers to UNet model here? ip_adapter.set_scale(ip_adapter_info.weight) print("ip_adapter:", ip_adapter) # get image embedding from CLIP and ImageProjModel print("getting image embeddings from IP-Adapter...") - num_samples = 1 # hardwiring for first pass + num_samples = 1 # hardwiring for first pass image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image) print("image cond embeds shape:", image_prompt_embeds.shape) print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)