performance: low-memory option for calculating guidance sequentially (#2732)

In theory, this reduces peak memory consumption by running the conditioned
and unconditioned predictions one after the other instead of in a
single mini-batch.
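
For the curious, here is the shape of the trade-off as a minimal sketch (hypothetical `model` and argument names, not the code in this diff):

```python
import torch

def guided_noise_batched(model, x, t, cond, uncond, guidance_scale):
    # current default: one forward pass over a mini-batch of 2 (uncond + cond)
    x_twice, t_twice = torch.cat([x] * 2), torch.cat([t] * 2)
    noise_uncond, noise_cond = model(x_twice, t_twice, torch.cat([uncond, cond])).chunk(2)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

def guided_noise_sequential(model, x, t, cond, uncond, guidance_scale):
    # this PR's optional path: two batch-size-1 passes; lower peak memory, roughly 2x the unet calls
    noise_uncond = model(x, t, uncond)
    noise_cond = model(x, t, cond)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```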

In practice, it doesn't reduce the reported "Max VRAM used for this
generation" for me, even without xformers. (But it does slow things down
by a good 18%.)

That suggests to me that the peak memory usage happens during VAE decoding,
not in the diffusion UNet, but ymmv. It does [improve things for gogurt's
16 GB
M1](https://github.com/invoke-ai/InvokeAI/pull/2732#issuecomment-1436187407),
so it seems worthwhile.

To try it out, use the `--sequential_guidance` option:
2dded68267/ldm/invoke/args.py (L487-L492)
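
It's an ordinary `store_true` switch, so append it to however you normally launch the CLI; leaving it off keeps the current batched behavior. Under the hood it just lands on `Globals`, which the diffuser component reads at construction time (a tiny sketch of the pattern, mirroring the hunks below):

```python
from ldm.invoke.globals import Globals

# main() copies args.sequential_guidance onto Globals after argument parsing (first hunk below),
# and InvokeAIDiffuserComponent.__init__ picks it up as self.sequential_guidance.
if Globals.sequential_guidance:
    print(">> Guidance will be computed sequentially (lower peak memory, slower)")
```
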
Lincoln Stein 2023-02-20 23:00:54 -05:00 committed by GitHub
commit 7fadd5e5c4
4 changed files with 63 additions and 25 deletions

View File

@@ -61,6 +61,7 @@ def main():
     Globals.always_use_cpu = args.always_use_cpu
     Globals.internet_available = args.internet_available and check_internet()
     Globals.disable_xformers = not args.xformers
+    Globals.sequential_guidance = args.sequential_guidance
     Globals.ckpt_convert = args.ckpt_convert
     print(f">> Internet connectivity is {Globals.internet_available}")

View File

@@ -91,14 +91,14 @@ import pydoc
 import re
 import shlex
 import sys
-import ldm.invoke
-import ldm.invoke.pngwriter
-from ldm.invoke.globals import Globals
-from ldm.invoke.prompt_parser import split_weighted_subprompts
 from argparse import Namespace
 from pathlib import Path
+import ldm.invoke
+import ldm.invoke.pngwriter
+from ldm.invoke.globals import Globals
+from ldm.invoke.prompt_parser import split_weighted_subprompts
 
 APP_ID = ldm.invoke.__app_id__
 APP_NAME = ldm.invoke.__app_name__
 APP_VERSION = ldm.invoke.__version__
@@ -488,6 +488,13 @@ class Args(object):
             action='store_true',
             help='Force free gpu memory before final decoding',
         )
+        model_group.add_argument(
+            '--sequential_guidance',
+            dest='sequential_guidance',
+            action='store_true',
+            help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
+                 "at the expense of speed",
+        )
         model_group.add_argument(
             '--xformers',
             action=argparse.BooleanOptionalAction,

View File

@@ -13,8 +13,8 @@ the attributes:
 import os
 import os.path as osp
-from pathlib import Path
 from argparse import Namespace
+from pathlib import Path
 from typing import Union
 
 Globals = Namespace()
@@ -48,6 +48,9 @@ Globals.internet_available = True
 # Whether to disable xformers
 Globals.disable_xformers = False
 
+# Low-memory tradeoff for guidance calculations.
+Globals.sequential_guidance = False
+
 # whether we are forcing full precision
 Globals.full_precision = False

View File

@@ -1,4 +1,3 @@
-import math
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
@@ -6,13 +5,20 @@ from typing import Callable, Optional, Union, Any, Dict
 import numpy as np
 import torch
 from diffusers.models.cross_attention import AttnProcessor
+from typing_extensions import TypeAlias
+from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
+ModelForwardCallback: TypeAlias = Union[
+    # x, t, conditioning, Optional[cross-attention kwargs]
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+]
 
 @dataclass(frozen=True)
 class PostprocessingSettings:
@@ -32,7 +38,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
-    last_percent_through = 0.0
+    sequential_guidance = False
 
     @dataclass
     class ExtraConditioningInfo:
@@ -45,8 +51,7 @@
             return self.cross_attention_control_args is not None
 
-    def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+    def __init__(self, model, model_forward_callback: ModelForwardCallback,
                  is_running_diffusers: bool=False,
                  ):
         """
@@ -58,7 +63,7 @@
         self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
-        self.last_percent_through = 0.0
+        self.sequential_guidance = Globals.sequential_guidance
 
     @contextmanager
     def custom_attention_context(self,
@@ -146,11 +151,20 @@
         wants_hybrid_conditioning = isinstance(conditioning, dict)
 
         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
+                                                                                       conditioning)
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
+                                                                                                           unconditioning,
+                                                                                                           conditioning,
+                                                                                                           cross_attention_control_types_to_do)
+        elif self.sequential_guidance:
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning)
         else:
-            unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning)
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
@@ -185,7 +199,7 @@
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
@@ -198,7 +212,17 @@
 
         return unconditioned_next_x, conditioned_next_x
 
-    def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+        # low-memory sequential path
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        if conditioned_next_x.device.type == 'mps':
+            # prevent a result filled with zeros. seems to be a torch bug.
+            conditioned_next_x = conditioned_next_x.clone()
+        return unconditioned_next_x, conditioned_next_x
+
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
         x_twice = torch.cat([x] * 2)
@@ -216,18 +240,21 @@
 
         return unconditioned_next_x, conditioned_next_x
 
-    def apply_cross_attention_controlled_conditioning(self,
+    def _apply_cross_attention_controlled_conditioning(self,
                                                       x: torch.Tensor,
                                                       sigma,
                                                       unconditioning,
                                                       conditioning,
                                                       cross_attention_control_types_to_do):
         if self.is_running_diffusers:
-            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
+                                                                                  conditioning,
+                                                                                  cross_attention_control_types_to_do)
         else:
-            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
+                                                                                cross_attention_control_types_to_do)
 
-    def apply_cross_attention_controlled_conditioning__diffusers(self,
+    def _apply_cross_attention_controlled_conditioning__diffusers(self,
                                                                   x: torch.Tensor,
                                                                   sigma,
                                                                   unconditioning,
@@ -250,7 +277,7 @@
 
         return unconditioned_next_x, conditioned_next_x
 
-    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of