mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Fix/long prompts (#3806)

commit 47b1a85e70
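In outline, reading the hunks below: the ad-hoc prompt-padding helper that previously lived inside StableDiffusionGeneratorPipeline (and threaded its encoder_attention_mask down the call chain via **kwargs) is removed from the pipeline, and the same pad-and-mask logic is centralized in InvokeAIDiffuserComponent._concat_conditionings_for_batch. As a result, every conditioning path (the batched classifier-free-guidance path, the low-memory sequential path, the cross-attention-control path, and the ControlNet branch) now pads unequal-length positive/negative prompt embeddings to a common length and builds the matching attention mask in one place, which is what makes prompts longer than one 77-token CLIP chunk behave consistently.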
@@ -422,7 +422,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
-        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -443,7 +442,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
-            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -469,7 +467,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device('cpu')
@@ -487,11 +484,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
+            additional_guidance=additional_guidance,
             run_id=run_id,
-            additional_guidance=additional_guidance,
+            callback=callback,
             control_data=control_data,
-            **kwargs,
 
-            callback=callback,
         )
         return result.latents, result.attention_map_saver
+
@@ -505,42 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
-        def _pad_conditioning(cond, target_len, encoder_attention_mask):
-            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
-
-            if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat([
-                    conditioning_attention_mask,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-                cond = torch.cat([
-                    cond,
-                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
-                ], dim=1)
-
-            if encoder_attention_mask is None:
-                encoder_attention_mask = conditioning_attention_mask
-            else:
-                encoder_attention_mask = torch.cat([
-                    encoder_attention_mask,
-                    conditioning_attention_mask,
-                ])
-
-            return cond, encoder_attention_mask
-
-        encoder_attention_mask = None
-        if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
-            max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
-            conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
-            )
-            conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
-                conditioning_data.text_embeddings, max_len, encoder_attention_mask
-            )
-
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -580,8 +542,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                encoder_attention_mask=encoder_attention_mask,
-                **kwargs,
             )
             latents = step_output.prev_sample
 
@@ -623,7 +583,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -638,8 +597,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         down_block_res_samples, mid_block_res_sample = None, None
 
         if control_data is not None:
-            # TODO: rewrite to pass with conditionings
-            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
             # control_data should be type List[ControlNetData]
             # this loop covers both ControlNet (one ControlNetData in list)
             # and MultiControlNet (multiple ControlNetData in list)
@@ -669,9 +626,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                     encoder_hidden_states = conditioning_data.text_embeddings
+                    encoder_attention_mask = None
                 else:
-                    encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                       conditioning_data.text_embeddings])
+                    encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+                        conditioning_data.unconditioned_embeddings,
+                        conditioning_data.text_embeddings,
+                    )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
                     controlnet_weight = control_datum.weight[step_index]
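Two details worth noting in the ControlNet hunk above. When cfg_injection is set, the ControlNet only sees the conditional embeddings (a batch of one prompt), so no padding can be needed and encoder_attention_mask is explicitly reset to None. In the batched branch, the pipeline now delegates to invokeai_diffuser._concat_conditionings_for_batch instead of a bare torch.cat, so the ControlNet receives exactly the same padded [uncond, cond] batch and mask as the UNet does.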
@@ -237,6 +237,39 @@ class InvokeAIDiffuserComponent:
         )
         return latents
 
+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+
+            if cond.shape[1] < max_len:
+                conditioning_attention_mask = torch.cat([
+                    conditioning_attention_mask,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+                cond = torch.cat([
+                    cond,
+                    torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
+                ], dim=1)
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat([
+                    encoder_attention_mask,
+                    conditioning_attention_mask,
+                ])
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
     def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
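The hunk above is the heart of the fix, so here is a minimal, self-contained sketch of the same pad-and-mask technique on dummy tensors. The helper names and shapes (a 77-token negative prompt vs. a 154-token positive prompt, 768-dim embeddings) are illustrative, not InvokeAI API; like the real method, it returns the mask as None when the lengths already match.

import torch

def concat_conditionings(uncond, cond):
    # Pad the shorter of the two prompt embeddings with zeros and record
    # which positions are real (1) vs padding (0) in an attention mask.
    def pad(t, max_len):
        mask = torch.ones(t.shape[0], t.shape[1], dtype=t.dtype, device=t.device)
        if t.shape[1] < max_len:
            pad_len = max_len - t.shape[1]
            mask = torch.cat([mask, torch.zeros(t.shape[0], pad_len, dtype=t.dtype, device=t.device)], dim=1)
            t = torch.cat([t, torch.zeros(t.shape[0], pad_len, t.shape[2], dtype=t.dtype, device=t.device)], dim=1)
        return t, mask

    if uncond.shape[1] == cond.shape[1]:
        return torch.cat([uncond, cond]), None  # equal lengths: no mask needed
    max_len = max(uncond.shape[1], cond.shape[1])
    uncond, uncond_mask = pad(uncond, max_len)
    cond, cond_mask = pad(cond, max_len)
    # Mask rows are stacked in the same order as the embedding batch rows.
    return torch.cat([uncond, cond]), torch.cat([uncond_mask, cond_mask])

uncond = torch.randn(1, 77, 768)   # short negative prompt (one CLIP chunk)
cond = torch.randn(1, 154, 768)    # long positive prompt (two CLIP chunks)
batch, mask = concat_conditionings(uncond, cond)
print(batch.shape, mask.shape)  # torch.Size([2, 154, 768]) torch.Size([2, 154])

Because the padded positions hold zero embeddings and are zeroed in the mask, they are excluded from cross-attention rather than attended to as spurious tokens.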
@@ -244,9 +277,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        both_conditionings = torch.cat([unconditioning, conditioning])
+        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
+            unconditioning, conditioning
+        )
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
+            x_twice, sigma_twice, both_conditionings,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         return unconditioned_next_x, conditioned_next_x
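For context, the batched path above doubles the latents and sigmas so a single UNet call evaluates the unconditioned and conditioned halves together, then splits the result with chunk(2). A toy sketch, with a stand-in for model_forward_callback; the combination step at the end is the standard classifier-free-guidance formula and happens elsewhere in this class, not in this hunk:

import torch

def fake_model_forward(x, sigma, conditionings, encoder_attention_mask=None, **kwargs):
    # Stand-in for model_forward_callback: shape-preserving on x.
    return x * 0.5

x = torch.randn(1, 4, 64, 64)                  # latents for one image
sigma = torch.tensor([14.6])
both_conditionings = torch.randn(2, 154, 768)  # [uncond, cond] after padding

x_twice = torch.cat([x] * 2)                   # batch of 2: one call covers both halves
sigma_twice = torch.cat([sigma] * 2)
both_results = fake_model_forward(x_twice, sigma_twice, both_conditionings)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)

guidance_scale = 7.5
guided = unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)
print(guided.shape)  # torch.Size([1, 4, 64, 64])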
@@ -260,8 +297,32 @@ class InvokeAIDiffuserComponent:
         **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning,
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
+            **kwargs,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning,
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
+            **kwargs,
+        )
         return unconditioned_next_x, conditioned_next_x
 
     # TODO: looks unused
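The sequential path above has to unbatch the ControlNet residuals: they were produced for a [uncond, cond] batch of two, but the low-memory path makes two separate batch-of-one UNet calls. A sketch of that splitting, with made-up channel and spatial sizes:

import torch

# ControlNet residuals computed for a batch of [uncond, cond] (batch size 2).
down_block_additional_residuals = [
    torch.randn(2, 320, 64, 64),
    torch.randn(2, 640, 32, 32),
]
mid_block_additional_residual = torch.randn(2, 1280, 8, 8)

uncond_down_block, cond_down_block = [], []
for down_block in down_block_additional_residuals:
    _uncond_down, _cond_down = down_block.chunk(2)  # split along the batch dim
    uncond_down_block.append(_uncond_down)
    cond_down_block.append(_cond_down)

uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
print(uncond_down_block[0].shape, uncond_mid_block.shape)
# torch.Size([1, 320, 64, 64]) torch.Size([1, 1280, 8, 8])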
@@ -295,6 +356,20 @@ class InvokeAIDiffuserComponent:
     ):
         context: Context = self.cross_attention_control_context
 
+        uncond_down_block, cond_down_block = None, None
+        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
+        if down_block_additional_residuals is not None:
+            uncond_down_block, cond_down_block = [], []
+            for down_block in down_block_additional_residuals:
+                _uncond_down, _cond_down = down_block.chunk(2)
+                uncond_down_block.append(_uncond_down)
+                cond_down_block.append(_cond_down)
+
+        uncond_mid_block, cond_mid_block = None, None
+        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
+        if mid_block_additional_residual is not None:
+            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
+
         cross_attn_processor_context = SwapCrossAttnContext(
             modified_text_embeddings=context.arguments.edited_conditioning,
             index_map=context.cross_attention_index_map,
@@ -307,6 +382,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=uncond_down_block,
+            mid_block_additional_residual=uncond_mid_block,
             **kwargs,
         )
 
@@ -319,6 +396,8 @@ class InvokeAIDiffuserComponent:
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            down_block_additional_residuals=cond_down_block,
+            mid_block_additional_residual=cond_mid_block,
             **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x