From 171f4aa71bd170caf810b76bbf30cd310a02b32b Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 19 Jan 2023 16:16:35 -0500
Subject: [PATCH] [feat] Provide option to disable xformers from command line

Starting `invoke.py` with --no-xformers will disable
memory-efficient-attention support if xformers is installed. --xformers
will enable support, but this is already the default.
---
 ldm/invoke/CLI.py                          | 10 +++-------
 ldm/invoke/args.py                         |  6 ++++++
 ldm/invoke/generator/diffusers_pipeline.py |  3 ++-
 ldm/invoke/globals.py                      |  3 +++
 4 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py
index 6fb0efeb8d..ef6389c7cc 100644
--- a/ldm/invoke/CLI.py
+++ b/ldm/invoke/CLI.py
@@ -45,6 +45,7 @@ def main():
     Globals.try_patchmatch = args.patchmatch
     Globals.always_use_cpu = args.always_use_cpu
     Globals.internet_available = args.internet_available and check_internet()
+    Globals.disable_xformers = not args.xformers
     print(f'>> Internet connectivity is {Globals.internet_available}')
 
     if not args.conf:
@@ -124,7 +125,7 @@ def main():
     # preload the model
     try:
         gen.load_model()
-    except KeyError as e:
+    except KeyError:
         pass
     except Exception as e:
         report_model_error(opt, e)
@@ -731,11 +732,6 @@ def del_config(model_name:str, gen, opt, completer):
     completer.update_models(gen.model_manager.list_models())
 
 def edit_model(model_name:str, gen, opt, completer):
-    current_model = gen.model_name
-#    if model_name == current_model:
-#        print("** Can't edit the active model. !switch to another model first. **")
-#        return
-
     manager = gen.model_manager
     if not (info := manager.model_info(model_name)):
         print(f'** Unknown model {model_name}')
@@ -887,7 +883,7 @@ def prepare_image_metadata(
     try:
         filename = opt.fnformat.format(**wildcards)
     except KeyError as e:
-        print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
+        print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
         filename = f'{prefix}.{seed}.png'
     except IndexError:
         print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 400d1f720d..c918e4fba7 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -482,6 +482,12 @@ class Args(object):
             action='store_true',
             help='Force free gpu memory before final decoding',
         )
+        model_group.add_argument(
+            '--xformers',
+            action=argparse.BooleanOptionalAction,
+            default=True,
+            help='Enable/disable xformers support (default enabled if installed)',
+        )
         model_group.add_argument(
             "--always_use_cpu",
             dest="always_use_cpu",
diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 5e62abf9df..54e9d555af 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -39,6 +39,7 @@ from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
 
@@ -306,7 +307,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )
 
-        if is_xformers_available():
+        if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
diff --git a/ldm/invoke/globals.py b/ldm/invoke/globals.py
index 137171aa33..5bd5597b78 100644
--- a/ldm/invoke/globals.py
+++ b/ldm/invoke/globals.py
@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True
 
+# Whether to disable xformers
+Globals.disable_xformers = False
+
 # whether we are forcing full precision
 Globals.full_precision = False