Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline.
This commit is contained in:
parent: 889d13e02a
commit: 20322d781e
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import nullcontext
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Optional

 import torch

@@ -14,7 +14,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     is_inpainting_model,
 )
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher


 class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
@@ -63,6 +63,11 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         masked_latents: Optional[torch.Tensor] = None,
         is_gradient_mask: bool = False,
     ) -> torch.Tensor:
+        if ip_adapter_data is not None:
+            raise NotImplementedError("ip_adapter_data is not supported in MultiDiffusionPipeline")
+        if t2i_adapter_data is not None:
+            raise NotImplementedError("t2i_adapter_data is not supported in MultiDiffusionPipeline")
+
         # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
         # cases where densoisings_start and denoising_end are set such that there are no timesteps.
         if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
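The added guard clauses make the unsupported paths fail fast instead of silently ignoring the adapter inputs. Below is a minimal, self-contained sketch of the same guard-clause pattern; the function name and signature are illustrative, not InvokeAI's actual API.

from typing import Any, Optional

import torch


def tiled_denoise(
    latents: torch.Tensor,
    ip_adapter_data: Optional[Any] = None,
    t2i_adapter_data: Optional[Any] = None,
) -> torch.Tensor:
    # Reject features the tiled (MultiDiffusion) path does not implement,
    # mirroring the guards added in this commit.
    if ip_adapter_data is not None:
        raise NotImplementedError("ip_adapter_data is not supported in MultiDiffusionPipeline")
    if t2i_adapter_data is not None:
        raise NotImplementedError("t2i_adapter_data is not supported in MultiDiffusionPipeline")
    return latents

Callers that still pass adapter data get an immediate NotImplementedError rather than a silently incorrect render.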
@@ -106,20 +111,14 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             is_gradient_mask=is_gradient_mask,
         )

-        use_ip_adapter = ip_adapter_data is not None
         use_regional_prompting = (
            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
         )
         unet_attention_patcher = None
         attn_ctx = nullcontext()

-        if use_ip_adapter or use_regional_prompting:
-            ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
-                if use_ip_adapter
-                else None
-            )
-            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+        if use_regional_prompting:
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
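After this change the attention patcher is applied only for regional prompting; when it is not needed, the denoising loop runs under contextlib.nullcontext(), a no-op context manager. Here is a small sketch of that choose-a-context pattern, using a hypothetical patch_attention() stand-in rather than the real UNetAttentionPatcher:

from contextlib import contextmanager, nullcontext


@contextmanager
def patch_attention(model):
    # Hypothetical stand-in for UNetAttentionPatcher.apply_ip_adapter_attention().
    print("patching attention processors for regional prompting")
    try:
        yield model
    finally:
        print("restoring original attention processors")


def denoise(model, use_regional_prompting: bool):
    # Pick the context once; nullcontext() keeps the unpatched path uniform.
    attn_ctx = patch_attention(model) if use_regional_prompting else nullcontext()
    with attn_ctx:
        pass  # the tiled denoising loop would run here


denoise(model=object(), use_regional_prompting=True)

Selecting the context object up front lets the loop body stay identical whether or not attention patching is in effect.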
@@ -146,8 +145,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                     mask=mask,
                     masked_latents=masked_latents,
                     control_data=control_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
                 )
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, "pred_original_sample", None)
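The unchanged tail of the loop reads pred_original_sample defensively, since not every scheduler's step output exposes that attribute. A tiny illustration of the getattr-with-default pattern follows; the StepOutput class is a made-up stand-in, not a diffusers type.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class StepOutput:
    # Made-up stand-in for a scheduler step result; some schedulers omit pred_original_sample.
    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


step_output = StepOutput(prev_sample=torch.zeros(1, 4, 64, 64))
latents = step_output.prev_sample
# getattr with a default keeps the loop working when the attribute is absent.
predicted_original = getattr(step_output, "pred_original_sample", None)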