Split ControlField and IpAdapterField.

This commit is contained in:
Ryan Dick
2023-09-06 13:36:00 -04:00
parent 94ec3da7b5
commit d776e0a0a9
10 changed files with 256 additions and 204 deletions

View File

@ -13,8 +13,12 @@ import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
@ -26,7 +30,12 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
from .diffusion import (
AttentionMapSaver,
BasicConditioningInfo,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
@dataclass
@ -96,7 +105,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
for k, v in step_output.items()
}
)
@ -360,7 +369,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: IPAdapterData = None,
ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@ -432,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: List[IPAdapterData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@ -445,12 +454,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
# print("ip_adapter_image: ", type(ip_adapter_image))
if ip_adapter_data is not None and len(ip_adapter_data) > 0:
ip_adapter_info = ip_adapter_data[0]
ip_adapter_image = ip_adapter_info.image
# initialize IPAdapter
print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
if ip_adapter_data is not None:
# Initialize IPAdapter
# FIXME:
# WARNING!
# IPAdapter constructor modifies UNet model in-place
@ -459,17 +464,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# and how to undo if ip_adapter_image is removed
# Should reimplement to use existing model management context etc.
#
if "sdxl" in ip_adapter_info.ip_adapter_model:
if "sdxl" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterXL")
ip_adapter = IPAdapterXL(
self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_info.ip_adapter_model:
elif "plus" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterPlus")
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
@ -477,18 +482,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
print("using IPAdapter")
ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
)
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_info.weight)
ip_adapter.set_scale(ip_adapter_data.weight)
print("ip_adapter:", ip_adapter)
# get image embedding from CLIP and ImageProjModel
print("getting image embeddings from IP-Adapter...")
num_samples = 1 # hardwiring for first pass
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
bs_embed, seq_len, _ = image_prompt_embeds.shape