performance(InvokeAIDiffuserComponent): add low-memory path for calculating conditioned and unconditioned predictions sequentially

Proof of concept. Still needs to be wired up to options or heuristics.
This commit is contained in:
Kevin Turner 2023-02-19 16:04:54 -08:00
parent aca9d74489
commit d0abe13b60

View File

@ -29,7 +29,7 @@ class InvokeAIDiffuserComponent:
* Hybrid conditioning (used for inpainting)
'''
debug_thresholding = False
sequential_conditioning = False
@dataclass
class ExtraConditioningInfo:
@ -149,9 +149,13 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do)
elif self.sequential_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning)
else:
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
conditioning)
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning)
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
@ -198,6 +202,16 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
if conditioned_next_x.device.type == 'mps':
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)