[feat] Provide option to disable xformers from command line

Starting `invoke.py` with --no-xformers will disable
memory-efficient-attention support if xformers is installed.

--xformers will enable support, but this is already the
default.
This commit is contained in:
Lincoln Stein 2023-01-19 16:16:35 -05:00
parent ce17051b28
commit 171f4aa71b
4 changed files with 14 additions and 8 deletions

View File

@ -45,6 +45,7 @@ def main():
Globals.try_patchmatch = args.patchmatch Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet() Globals.internet_available = args.internet_available and check_internet()
Globals.disable_xformers = not args.xformers
print(f'>> Internet connectivity is {Globals.internet_available}') print(f'>> Internet connectivity is {Globals.internet_available}')
if not args.conf: if not args.conf:
@ -124,7 +125,7 @@ def main():
# preload the model # preload the model
try: try:
gen.load_model() gen.load_model()
except KeyError as e: except KeyError:
pass pass
except Exception as e: except Exception as e:
report_model_error(opt, e) report_model_error(opt, e)
@ -731,11 +732,6 @@ def del_config(model_name:str, gen, opt, completer):
completer.update_models(gen.model_manager.list_models()) completer.update_models(gen.model_manager.list_models())
def edit_model(model_name:str, gen, opt, completer): def edit_model(model_name:str, gen, opt, completer):
current_model = gen.model_name
# if model_name == current_model:
# print("** Can't edit the active model. !switch to another model first. **")
# return
manager = gen.model_manager manager = gen.model_manager
if not (info := manager.model_info(model_name)): if not (info := manager.model_info(model_name)):
print(f'** Unknown model {model_name}') print(f'** Unknown model {model_name}')
@ -887,7 +883,7 @@ def prepare_image_metadata(
try: try:
filename = opt.fnformat.format(**wildcards) filename = opt.fnformat.format(**wildcards)
except KeyError as e: except KeyError as e:
print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead') print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use {{prefix}}.{{seed}}.png\' instead')
filename = f'{prefix}.{seed}.png' filename = f'{prefix}.{seed}.png'
except IndexError: except IndexError:
print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead') print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')

View File

@ -482,6 +482,12 @@ class Args(object):
action='store_true', action='store_true',
help='Force free gpu memory before final decoding', help='Force free gpu memory before final decoding',
) )
model_group.add_argument(
'--xformers',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable/disable xformers support (default enabled if installed)',
)
model_group.add_argument( model_group.add_argument(
"--always_use_cpu", "--always_use_cpu",
dest="always_use_cpu", dest="always_use_cpu",

View File

@ -39,6 +39,7 @@ from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager from ldm.modules.textual_inversion_manager import TextualInversionManager
@ -306,7 +307,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager textual_inversion_manager=self.textual_inversion_manager
) )
if is_xformers_available(): if is_xformers_available() and not Globals.disable_xformers:
self.enable_xformers_memory_efficient_attention() self.enable_xformers_memory_efficient_attention()
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,

View File

@ -43,6 +43,9 @@ Globals.always_use_cpu = False
# The CLI will test connectivity at startup time. # The CLI will test connectivity at startup time.
Globals.internet_available = True Globals.internet_available = True
# Whether to disable xformers
Globals.disable_xformers = False
# whether we are forcing full precision # whether we are forcing full precision
Globals.full_precision = False Globals.full_precision = False