refactor hybrid and cross attention control codepaths for readability

This commit is contained in:
Damian at mba
2022-10-27 19:40:37 +02:00
parent dc86fc92ce
commit f73d349dfe
2 changed files with 112 additions and 80 deletions

View File

@ -114,7 +114,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
conditioning = original_embeddings conditioning = original_embeddings
edited_conditioning = edited_embeddings edited_conditioning = edited_embeddings
print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments( cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning, edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes, edit_opcodes = edit_opcodes,
@ -124,7 +124,13 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens) conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens) unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
conditioning = flatten_hybrid_conditioning(unconditioning, conditioning) if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
return ( return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo( unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args cross_attention_control_args=cac_args
@ -172,19 +178,17 @@ def flatten_hybrid_conditioning(uncond, cond):
that is a tensor (used by cross attention) vs one that has additional that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid' dimensions as well, as used by 'hybrid'
''' '''
if isinstance(cond, dict):
assert isinstance(uncond, dict) assert isinstance(uncond, dict)
cond_in = dict() assert isinstance(cond, dict)
cond_flattened = dict()
for k in cond: for k in cond:
if isinstance(cond[k], list): if isinstance(cond[k], list):
cond_in[k] = [ cond_flattened[k] = [
torch.cat([uncond[k][i], cond[k][i]]) torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k])) for i in range(len(cond[k]))
] ]
else: else:
cond_in[k] = torch.cat([uncond[k], cond[k]]) cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return cond_in return uncond, cond_flattened
else:
return cond

View File

@ -12,7 +12,8 @@ class InvokeAIDiffuserComponent:
all InvokeAI diffusion procedures. all InvokeAI diffusion procedures.
At the moment it includes the following features: At the moment it includes the following features:
* Cross Attention Control ("prompt2prompt") * Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
''' '''
@ -47,51 +48,69 @@ class InvokeAIDiffuserComponent:
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count #todo: apply edit_options using step_count
def remove_cross_attention_control(self): def remove_cross_attention_control(self):
self.conditioning = None self.conditioning = None
self.cross_attention_control_context = None self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model) CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict], unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float, unconditional_guidance_scale: float,
step_index: int=None step_index: Optional[int]=None
): ):
""" """
:param x: Current latents :param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur :param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: [B x 77 x 768] embeddings for unconditioned output :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param conditioning: [B x 77 x 768] embeddings for conditioned output :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. If None, will be estimated by comparing sigma against self.model.sigmas.
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
""" """
CrossAttentionControl.clear_requests(self.model) CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
if step_index is not None: percent_through = self.estimate_percent_through(step_index, sigma)
# percent_through will never reach 1.0 (but this is intended)
percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
else:
# find the current sigma in the sigma sequence
# todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
#print('estimated percent_through', percent_through, 'from sigma', sigma.item())
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
if len(cross_attention_control_types_to_do)==0: wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
#print('not doing cross attention control') wants_hybrid_conditioning = isinstance(conditioning, dict)
# faster batched path
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
if isinstance(conditioning, dict): both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
both_conditionings).chunk(2)
return unconditioned_next_x, conditioned_next_x
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict) assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = dict() both_conditionings = dict()
for k in conditioning: for k in conditioning:
if isinstance(conditioning[k], list): if isinstance(conditioning[k], list):
@ -101,10 +120,11 @@ class InvokeAIDiffuserComponent:
] ]
else: else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
else:
both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
else: return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS) # slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@ -127,14 +147,22 @@ class InvokeAIDiffuserComponent:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model) CrossAttentionControl.clear_requests(self.model)
return unconditioned_next_x, conditioned_next_x
# to scale how much effect conditioning has, calculate the changes it does and then scale that def estimate_percent_through(self, step_index, sigma):
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale if step_index is not None and self.cross_attention_control_context is not None:
combined_next_x = unconditioned_next_x + scaled_delta # percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
return combined_next_x
# todo: make this work # todo: make this work
@classmethod @classmethod