mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix avoid storing extra conditioning info in two places.
This commit is contained in:
parent
8fb297e5f6
commit
4a1acd4db9
@ -360,7 +360,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
) -> ConditioningData:
|
) -> ConditioningData:
|
||||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
extra_conditioning_info = c.extra_conditioning
|
|
||||||
|
|
||||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
@ -370,7 +369,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
extra=extra_conditioning_info,
|
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
threshold=0.0, # threshold,
|
threshold=0.0, # threshold,
|
||||||
warmup=0.2, # warmup,
|
warmup=0.2, # warmup,
|
||||||
|
@ -427,10 +427,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
return latents, attention_map_saver
|
return latents, attention_map_saver
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
ip_adapter_unet_patcher = None
|
||||||
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
|
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||||
self.invokeai_diffuser.model,
|
self.invokeai_diffuser.model,
|
||||||
extra_conditioning_info=conditioning_data.extra,
|
extra_conditioning_info=extra_conditioning_info,
|
||||||
step_count=len(self.scheduler.timesteps),
|
step_count=len(self.scheduler.timesteps),
|
||||||
)
|
)
|
||||||
self.use_ip_adapter = False
|
self.use_ip_adapter = False
|
||||||
|
@ -21,11 +21,7 @@ class ExtraConditioningInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BasicConditioningInfo:
|
class BasicConditioningInfo:
|
||||||
embeds: torch.Tensor
|
embeds: torch.Tensor
|
||||||
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
|
|
||||||
# should only be stored in one place.
|
|
||||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||||
# weight: float
|
|
||||||
# mode: ConditioningAlgo
|
|
||||||
|
|
||||||
def to(self, device, dtype=None):
|
def to(self, device, dtype=None):
|
||||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||||
@ -83,7 +79,6 @@ class ConditioningData:
|
|||||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||||
"""
|
"""
|
||||||
guidance_rescale_multiplier: float = 0
|
guidance_rescale_multiplier: float = 0
|
||||||
extra: Optional[ExtraConditioningInfo] = None
|
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||||
"""
|
"""
|
||||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||||
|
Loading…
Reference in New Issue
Block a user