Fix avoid storing extra conditioning info in two places.

This commit is contained in:
Ryan Dick 2024-02-13 17:50:20 -05:00
parent 273994b742
commit 16e574825c
3 changed files with 4 additions and 10 deletions

View File

@ -226,7 +226,7 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
positive_conditioning: ConditioningField = InputField( positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
) )
negative_conditioning: ConditioningField = InputField( negative_conditioning: ConditioningField = InputField(
@ -332,7 +332,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -342,7 +341,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings=c, text_embeddings=c,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold, threshold=0.0, # threshold,
warmup=0.2, # warmup, warmup=0.2, # warmup,

View File

@ -420,10 +420,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return latents, attention_map_saver return latents, attention_map_saver
ip_adapter_unet_patcher = None ip_adapter_unet_patcher = None
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control: extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context( attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra, extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps), step_count=len(self.scheduler.timesteps),
) )
self.use_ip_adapter = False self.use_ip_adapter = False

View File

@ -21,11 +21,7 @@ class ExtraConditioningInfo:
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
embeds: torch.Tensor embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo] extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None): def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype) self.embeds = self.embeds.to(device=device, dtype=dtype)
@ -78,7 +74,6 @@ class ConditioningData:
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
""" """
guidance_rescale_multiplier: float = 0 guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict) scheduler_args: dict[str, Any] = field(default_factory=dict)
""" """
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing(). Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().