Mirror of https://github.com/invoke-ai/InvokeAI
Clean-up code a bit
commit 7a8f14d595 (parent 59ba9fc0f6)
@@ -34,6 +34,7 @@ from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
+    BasicConditioningInfo,
 )
 from ..util import normalize_device
 
@@ -92,8 +93,7 @@ class AddsMaskGuidance:
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
     scheduler: SchedulerMixin
-    noise: Optional[torch.Tensor] = None
-    _debug: Optional[Callable] = None
+    noise: torch.Tensor
 
     def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
         output_class = step_output.__class__  # We'll create a new one with masked data.
@@ -123,14 +123,13 @@ class AddsMaskGuidance:
             # some schedulers expect t to be one-dimensional.
             # TODO: file diffusers bug about inconsistency?
             t = einops.repeat(t, "-> batch", batch=batch_size)
-        if self.noise is not None:
-            mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
+        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
+        # get very confused about what is happening from step to step when we do that.
+        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
+        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
+        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
         masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
-        if self._debug:
-            self._debug(masked_input, f"t={t} lerped")
         return masked_input
 
 
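Context for the two hunks above: the masked-blend step follows a common latent-inpainting pattern. The cached latents of the original image are re-noised to the current timestep, broadcast across the batch, and linearly interpolated with the working latents using the mask as the weight. Below is a minimal sketch of that pattern only; the scheduler choice (diffusers' DDPMScheduler), the shapes, and the box mask are assumptions for illustration, not InvokeAI's AddsMaskGuidance itself.

# Sketch: blend cached image latents back into the working latents at timestep t.
# Assumptions (not from the diff): DDPMScheduler, 4-channel 64x64 latents, batch of 2.
import torch
import einops
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(2, 4, 64, 64)       # working latents for the current step
mask = torch.ones(1, 1, 64, 64)           # 1 = keep working latents, 0 = keep original image
mask[:, :, 16:48, 16:48] = 0.0            # region to preserve from the original image
mask_latents = torch.randn(1, 4, 64, 64)  # latents of the original (unmasked) image
noise = torch.randn_like(mask_latents)    # fixed once per generation, not re-randomized per step

batch_size = latents.shape[0]
t = torch.tensor([500])                   # current denoising timestep, shape (1,)

# Re-noise the original-image latents to the current timestep, then broadcast
# to the batch and blend: where mask == 0 the original image content wins.
noised = scheduler.add_noise(mask_latents, noise, t)
noised = einops.repeat(noised, "b c h w -> (repeat b) c h w", repeat=batch_size)
mask_b = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
masked_input = torch.lerp(noised.to(latents.dtype), latents, mask_b.to(latents.dtype))
print(masked_input.shape)  # torch.Size([2, 4, 64, 64])

Keeping the noise tensor fixed across steps matches the comment restored in the hunk: multistep schedulers assume a consistent noising trajectory, so re-sampling noise at every step would break their internal bookkeeping.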
@@ -202,8 +201,8 @@ class ControlNetData:
 
 @dataclass
 class ConditioningData:
-    unconditioned_embeddings: Any  # TODO: type
-    text_embeddings: Any  # TODO: type
+    unconditioned_embeddings: BasicConditioningInfo
+    text_embeddings: BasicConditioningInfo
     guidance_scale: Union[float, List[float]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
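The guidance_scale documented above is the classifier-free guidance weight: the model is evaluated on both the unconditioned and the text-conditioned embeddings, and the two noise predictions are combined as uncond + scale * (cond - uncond). The snippet below is a generic sketch of that combination with dummy tensors; it is not InvokeAI's InvokeAIDiffuserComponent.

# Generic classifier-free guidance combination (illustrative shapes only).
import torch

guidance_scale = 7.5
noise_pred_uncond = torch.randn(1, 4, 64, 64)  # prediction from unconditioned embeddings
noise_pred_text = torch.randn(1, 4, 64, 64)    # prediction from text-conditioned embeddings

# scale > 1 pushes the result toward the text-conditioned direction
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])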
@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
 """
 from .cross_attention_control import InvokeAICrossAttentionMixin
 from .cross_attention_map_saving import AttentionMapSaver
-from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
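Re-exporting BasicConditioningInfo and SDXLConditioningInfo from the package __init__ lets callers annotate their conditioning the way ConditioningData now does, instead of passing untyped values. As a rough idea of what such containers typically carry, the sketch below uses hypothetical class and field names (they are assumptions for illustration, not the definitions in shared_invokeai_diffusion): a base container for the per-token prompt embeddings, and an SDXL variant adding the pooled embedding and size/crop conditioning.

# Hypothetical conditioning-info containers; names and fields are illustrative assumptions.
from dataclasses import dataclass
import torch


@dataclass
class BasicConditioningInfoSketch:
    embeds: torch.Tensor  # per-token text embeddings, e.g. (1, 77, 2048)


@dataclass
class SDXLConditioningInfoSketch(BasicConditioningInfoSketch):
    pooled_embeds: torch.Tensor  # pooled text embedding, e.g. (1, 1280)
    add_time_ids: torch.Tensor   # original size / crop / target size conditioning


cond = SDXLConditioningInfoSketch(
    embeds=torch.zeros(1, 77, 2048),
    pooled_embeds=torch.zeros(1, 1280),
    add_time_ids=torch.zeros(1, 6),
)
print(cond.embeds.shape, cond.pooled_embeds.shape)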