mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor hybrid and cross attention control codepaths for readability
This commit is contained in:
@ -114,7 +114,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
|
|
||||||
conditioning = original_embeddings
|
conditioning = original_embeddings
|
||||||
edited_conditioning = edited_embeddings
|
edited_conditioning = edited_embeddings
|
||||||
print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||||
cac_args = CrossAttentionControl.Arguments(
|
cac_args = CrossAttentionControl.Arguments(
|
||||||
edited_conditioning = edited_conditioning,
|
edited_conditioning = edited_conditioning,
|
||||||
edit_opcodes = edit_opcodes,
|
edit_opcodes = edit_opcodes,
|
||||||
@ -124,7 +124,13 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||||
|
|
||||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
|
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
|
||||||
conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
if isinstance(conditioning, dict):
|
||||||
|
# hybrid conditioning is in play
|
||||||
|
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||||
|
if cac_args is not None:
|
||||||
|
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||||
|
cac_args = None
|
||||||
|
|
||||||
return (
|
return (
|
||||||
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
cross_attention_control_args=cac_args
|
cross_attention_control_args=cac_args
|
||||||
@ -172,19 +178,17 @@ def flatten_hybrid_conditioning(uncond, cond):
|
|||||||
that is a tensor (used by cross attention) vs one that has additional
|
that is a tensor (used by cross attention) vs one that has additional
|
||||||
dimensions as well, as used by 'hybrid'
|
dimensions as well, as used by 'hybrid'
|
||||||
'''
|
'''
|
||||||
if isinstance(cond, dict):
|
|
||||||
assert isinstance(uncond, dict)
|
assert isinstance(uncond, dict)
|
||||||
cond_in = dict()
|
assert isinstance(cond, dict)
|
||||||
|
cond_flattened = dict()
|
||||||
for k in cond:
|
for k in cond:
|
||||||
if isinstance(cond[k], list):
|
if isinstance(cond[k], list):
|
||||||
cond_in[k] = [
|
cond_flattened[k] = [
|
||||||
torch.cat([uncond[k][i], cond[k][i]])
|
torch.cat([uncond[k][i], cond[k][i]])
|
||||||
for i in range(len(cond[k]))
|
for i in range(len(cond[k]))
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
cond_in[k] = torch.cat([uncond[k], cond[k]])
|
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
|
||||||
return cond_in
|
return uncond, cond_flattened
|
||||||
else:
|
|
||||||
return cond
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
all InvokeAI diffusion procedures.
|
all InvokeAI diffusion procedures.
|
||||||
|
|
||||||
At the moment it includes the following features:
|
At the moment it includes the following features:
|
||||||
* Cross Attention Control ("prompt2prompt")
|
* Cross attention control ("prompt2prompt")
|
||||||
|
* Hybrid conditioning (used for inpainting)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
@ -47,51 +48,69 @@ class InvokeAIDiffuserComponent:
|
|||||||
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
||||||
#todo: apply edit_options using step_count
|
#todo: apply edit_options using step_count
|
||||||
|
|
||||||
|
|
||||||
def remove_cross_attention_control(self):
|
def remove_cross_attention_control(self):
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||||
|
|
||||||
|
|
||||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||||
unconditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict],
|
unconditioning: Union[torch.Tensor,dict],
|
||||||
|
conditioning: Union[torch.Tensor,dict],
|
||||||
unconditional_guidance_scale: float,
|
unconditional_guidance_scale: float,
|
||||||
step_index: int=None
|
step_index: Optional[int]=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param x: Current latents
|
:param x: current latents
|
||||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||||
:param unconditioning: [B x 77 x 768] embeddings for unconditioned output
|
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||||
:param conditioning: [B x 77 x 768] embeddings for conditioned output
|
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||||
:param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase.
|
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
CrossAttentionControl.clear_requests(self.model)
|
||||||
cross_attention_control_types_to_do = []
|
|
||||||
|
|
||||||
|
cross_attention_control_types_to_do = []
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
if step_index is not None:
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
# percent_through will never reach 1.0 (but this is intended)
|
|
||||||
percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
|
|
||||||
else:
|
|
||||||
# find the current sigma in the sigma sequence
|
|
||||||
# todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence
|
|
||||||
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
|
||||||
# flip because sigmas[0] is for the fully denoised image
|
|
||||||
# percent_through must be <1
|
|
||||||
percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
|
||||||
#print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
|
||||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||||
|
|
||||||
if len(cross_attention_control_types_to_do)==0:
|
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||||
#print('not doing cross attention control')
|
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||||
# faster batched path
|
|
||||||
x_twice = torch.cat([x]*2)
|
if wants_hybrid_conditioning:
|
||||||
sigma_twice = torch.cat([sigma]*2)
|
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
|
||||||
if isinstance(conditioning, dict):
|
elif wants_cross_attention_control:
|
||||||
|
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||||
|
else:
|
||||||
|
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||||
|
|
||||||
|
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||||
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
||||||
|
combined_next_x = unconditioned_next_x + scaled_delta
|
||||||
|
|
||||||
|
return combined_next_x
|
||||||
|
|
||||||
|
|
||||||
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
|
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||||
|
# fast batched path
|
||||||
|
x_twice = torch.cat([x] * 2)
|
||||||
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
|
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
||||||
|
both_conditionings).chunk(2)
|
||||||
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||||
|
assert isinstance(conditioning, dict)
|
||||||
assert isinstance(unconditioning, dict)
|
assert isinstance(unconditioning, dict)
|
||||||
|
x_twice = torch.cat([x] * 2)
|
||||||
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
both_conditionings = dict()
|
both_conditionings = dict()
|
||||||
for k in conditioning:
|
for k in conditioning:
|
||||||
if isinstance(conditioning[k], list):
|
if isinstance(conditioning[k], list):
|
||||||
@ -101,11 +120,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||||
else:
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
||||||
else:
|
return unconditioned_next_x, conditioned_next_x
|
||||||
#print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
|
||||||
|
|
||||||
|
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||||
|
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
# slower non-batched path (20% slower on mac MPS)
|
||||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||||
@ -127,14 +147,22 @@ class InvokeAIDiffuserComponent:
|
|||||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
CrossAttentionControl.clear_requests(self.model)
|
||||||
|
|
||||||
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
def estimate_percent_through(self, step_index, sigma):
|
||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
if step_index is not None and self.cross_attention_control_context is not None:
|
||||||
combined_next_x = unconditioned_next_x + scaled_delta
|
# percent_through will never reach 1.0 (but this is intended)
|
||||||
|
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||||
|
# find the best possible index of the current sigma in the sigma sequence
|
||||||
|
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
||||||
|
# flip because sigmas[0] is for the fully denoised image
|
||||||
|
# percent_through must be <1
|
||||||
|
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
||||||
|
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||||
|
|
||||||
return combined_next_x
|
|
||||||
|
|
||||||
# todo: make this work
|
# todo: make this work
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Reference in New Issue
Block a user