mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix: encoder_attention_mask was not previously passed to the UNet, and even if it had been passed it would have broken sequential guidance runs, so the logic was rewritten
This commit is contained in:
parent
1d3fda80aa
commit
1c680a7147
@ -507,40 +507,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
|
||||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
|
||||||
|
|
||||||
if cond.shape[1] < max_len:
|
|
||||||
conditioning_attention_mask = torch.cat([
|
|
||||||
conditioning_attention_mask,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
cond = torch.cat([
|
|
||||||
cond,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
if encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = conditioning_attention_mask
|
|
||||||
else:
|
|
||||||
encoder_attention_mask = torch.cat([
|
|
||||||
encoder_attention_mask,
|
|
||||||
conditioning_attention_mask,
|
|
||||||
])
|
|
||||||
|
|
||||||
return cond, encoder_attention_mask
|
|
||||||
|
|
||||||
encoder_attention_mask = None
|
|
||||||
if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
|
|
||||||
max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
|
|
||||||
conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
|
|
||||||
conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
|
|
||||||
)
|
|
||||||
conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
|
|
||||||
conditioning_data.text_embeddings, max_len, encoder_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
@ -580,7 +546,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
@ -638,8 +603,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
# TODO: rewrite to pass with conditionings
|
|
||||||
encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
|
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
@ -669,9 +632,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings
|
encoder_hidden_states = conditioning_data.text_embeddings
|
||||||
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
|
encoder_hidden_states, encoder_hidden_states = self.invokeai_diffuser._concat_conditionings_for_batch(
|
||||||
conditioning_data.text_embeddings])
|
conditioning_data.unconditioned_embeddings,
|
||||||
|
conditioning_data.text_embeddings,
|
||||||
|
)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
controlnet_weight = control_datum.weight[step_index]
|
controlnet_weight = control_datum.weight[step_index]
|
||||||
|
@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||||
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
|
if cond.shape[1] < max_len:
|
||||||
|
conditioning_attention_mask = torch.cat([
|
||||||
|
conditioning_attention_mask,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
cond = torch.cat([
|
||||||
|
cond,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = conditioning_attention_mask
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = torch.cat([
|
||||||
|
encoder_attention_mask,
|
||||||
|
conditioning_attention_mask,
|
||||||
|
])
|
||||||
|
|
||||||
|
return cond, encoder_attention_mask
|
||||||
|
|
||||||
|
encoder_attention_mask = None
|
||||||
|
if unconditioning.shape[1] != conditioning.shape[1]:
|
||||||
|
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
||||||
|
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||||
|
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||||
|
|
||||||
|
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
|
unconditioning, conditioning
|
||||||
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
x_twice, sigma_twice, both_conditionings,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
Loading…
x
Reference in New Issue
Block a user