Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
performance: low-memory option for calculating guidance sequentially (#2732)
In theory, this reduces peak memory consumption by doing the conditioned and
unconditioned predictions one after the other instead of in a single
mini-batch.

In practice, it doesn't reduce the reported "Max VRAM used for this
generation" for me, even without xformers. (But it does slow things down by a
good 18%.) That suggests to me that the peak memory usage is during VAE
decoding, not the diffusion unet, but ymmv. It does [improve things for
gogurt's 16 GB M1](https://github.com/invoke-ai/InvokeAI/pull/2732#issuecomment-1436187407),
so it seems worthwhile.

To try it out, use the `--sequential_guidance` option, added in
2dded68267/ldm/invoke/args.py (L487-L492).
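For context, here is a minimal sketch of the trade-off, using a hypothetical `unet(x, sigma, conditioning)` callable and toy helper names — not InvokeAI's actual forward callback, which appears in the diff below:

```python
import torch

def guidance_batched(unet, x, sigma, uncond, cond, scale):
    # Fast path: run the unconditioned and conditioned predictions in one
    # mini-batch of size 2, roughly doubling peak activation memory in the unet.
    x_twice = torch.cat([x] * 2)
    sigma_twice = torch.cat([sigma] * 2)
    both = torch.cat([uncond, cond])
    uncond_next, cond_next = unet(x_twice, sigma_twice, both).chunk(2)
    return uncond_next + scale * (cond_next - uncond_next)

def guidance_sequential(unet, x, sigma, uncond, cond, scale):
    # Low-memory path: two separate forward passes, so activations for only
    # one prediction are alive at a time, at the cost of extra wall-clock time.
    uncond_next = unet(x, sigma, uncond)
    cond_next = unet(x, sigma, cond)
    return uncond_next + scale * (cond_next - uncond_next)
```

Both compute the same classifier-free-guidance result; only peak memory and speed differ.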
commit 7fadd5e5c4
--- a/ldm/invoke/CLI.py
+++ b/ldm/invoke/CLI.py
@@ -61,6 +61,7 @@ def main():
     Globals.always_use_cpu = args.always_use_cpu
     Globals.internet_available = args.internet_available and check_internet()
     Globals.disable_xformers = not args.xformers
+    Globals.sequential_guidance = args.sequential_guidance
     Globals.ckpt_convert = args.ckpt_convert

     print(f">> Internet connectivity is {Globals.internet_available}")

--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -91,14 +91,14 @@ import pydoc
 import re
 import shlex
 import sys
-import ldm.invoke
-import ldm.invoke.pngwriter
-
-from ldm.invoke.globals import Globals
-from ldm.invoke.prompt_parser import split_weighted_subprompts
 from argparse import Namespace
 from pathlib import Path
+
+import ldm.invoke
+import ldm.invoke.pngwriter
+from ldm.invoke.globals import Globals
+from ldm.invoke.prompt_parser import split_weighted_subprompts

 APP_ID = ldm.invoke.__app_id__
 APP_NAME = ldm.invoke.__app_name__
 APP_VERSION = ldm.invoke.__version__

@@ -488,6 +488,13 @@ class Args(object):
             action='store_true',
             help='Force free gpu memory before final decoding',
         )
+        model_group.add_argument(
+            '--sequential_guidance',
+            dest='sequential_guidance',
+            action='store_true',
+            help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
+                 "at the expense of speed",
+        )
         model_group.add_argument(
             '--xformers',
             action=argparse.BooleanOptionalAction,

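With the option wired through `args.py` and into `Globals` as above, it can be enabled by passing `--sequential_guidance` to the InvokeAI CLI at launch; since it is a `store_true` flag that defaults to off, existing behavior is unchanged unless explicitly requested.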
--- a/ldm/invoke/globals.py
+++ b/ldm/invoke/globals.py
@@ -13,8 +13,8 @@ the attributes:

 import os
 import os.path as osp
-from pathlib import Path
 from argparse import Namespace
+from pathlib import Path
 from typing import Union

 Globals = Namespace()

@@ -48,6 +48,9 @@ Globals.internet_available = True
 # Whether to disable xformers
 Globals.disable_xformers = False

+# Low-memory tradeoff for guidance calculations.
+Globals.sequential_guidance = False
+
 # whether we are forcing full precision
 Globals.full_precision = False

--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -1,4 +1,3 @@
-import math
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil

@@ -6,13 +5,20 @@ from typing import Callable, Optional, Union, Any, Dict

 import numpy as np
 import torch

 from diffusers.models.cross_attention import AttnProcessor
+from typing_extensions import TypeAlias

+from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

+ModelForwardCallback: TypeAlias = Union[
+    # x, t, conditioning, Optional[cross-attention kwargs]
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+]
+
 @dataclass(frozen=True)
 class PostprocessingSettings:

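The `ModelForwardCallback` alias admits two callback shapes. A hedged illustration with toy placeholder functions (not InvokeAI's real callbacks) — presumably the four-argument form covers the diffusers path, which forwards cross-attention kwargs, while the three-argument form covers callers that don't:

```python
import torch
from typing import Optional

def forward_with_kwargs(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor,
                        cross_attention_kwargs: Optional[dict] = None) -> torch.Tensor:
    return x  # placeholder; a real callback returns the model's noise prediction

def forward_plain(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
    return x  # placeholder

# Both satisfy ModelForwardCallback, so InvokeAIDiffuserComponent can call
# either one as self.model_forward_callback(x, sigma, conditioning).
```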
@@ -32,7 +38,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
-    last_percent_through = 0.0
+    sequential_guidance = False

     @dataclass
     class ExtraConditioningInfo:

@@ -45,8 +51,7 @@ class InvokeAIDiffuserComponent:
         return self.cross_attention_control_args is not None


-    def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+    def __init__(self, model, model_forward_callback: ModelForwardCallback,
                  is_running_diffusers: bool=False,
                  ):
         """

@@ -58,7 +63,7 @@ class InvokeAIDiffuserComponent:
         self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
-        self.last_percent_through = 0.0
+        self.sequential_guidance = Globals.sequential_guidance

     @contextmanager
     def custom_attention_context(self,

@@ -146,11 +151,20 @@ class InvokeAIDiffuserComponent:
         wants_hybrid_conditioning = isinstance(conditioning, dict)

         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
+                                                                                       conditioning)
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
+                                                                                                           unconditioning,
+                                                                                                           conditioning,
+                                                                                                           cross_attention_control_types_to_do)
+        elif self.sequential_guidance:
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning)
+
         else:
-            unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning)

         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)

@@ -185,7 +199,7 @@ class InvokeAIDiffuserComponent:

     # methods below are called from do_diffusion_step and should be considered private to this class.

-    def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

@@ -198,7 +212,17 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x


-    def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+        # low-memory sequential path
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        if conditioned_next_x.device.type == 'mps':
+            # prevent a result filled with zeros. seems to be a torch bug.
+            conditioned_next_x = conditioned_next_x.clone()
+        return unconditioned_next_x, conditioned_next_x
+
+
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
         x_twice = torch.cat([x] * 2)

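A rough way to test the commit message's observation about unchanged peak VRAM — a sketch assuming a CUDA device and the toy `guidance_*` helpers from the sketch near the top, not how InvokeAI computes its "Max VRAM used for this generation" figure:

```python
import torch

def peak_vram_mib(fn, *args):
    # Reset the CUDA peak-memory counter, run one guidance step, report the peak.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    return torch.cuda.max_memory_allocated() / 2**20
```

If both paths report the same peak, the high-water mark is being set somewhere else (e.g. VAE decoding), which matches the observation above.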
@@ -216,18 +240,21 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x


-    def apply_cross_attention_controlled_conditioning(self,
+    def _apply_cross_attention_controlled_conditioning(self,
                                                       x: torch.Tensor,
                                                       sigma,
                                                       unconditioning,
                                                       conditioning,
                                                       cross_attention_control_types_to_do):
         if self.is_running_diffusers:
-            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
+                                                                                  conditioning,
+                                                                                  cross_attention_control_types_to_do)
         else:
-            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
+                                                                                cross_attention_control_types_to_do)

-    def apply_cross_attention_controlled_conditioning__diffusers(self,
+    def _apply_cross_attention_controlled_conditioning__diffusers(self,
                                                                  x: torch.Tensor,
                                                                  sigma,
                                                                  unconditioning,

@@ -250,7 +277,7 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x


-    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of