add always_use_cpu arg to bypass MPS

Damian Stewart 2022-12-04 15:15:39 +01:00
parent e0495a7440
commit f48706efee
5 changed files with 41 additions and 27 deletions

View File

@@ -27,7 +27,7 @@ def main():
     """Initialize command-line parsers and the diffusion model"""
     global infile
     print('* Initializing, be patient...')

     opt = Args()
     args = opt.parse_args()
     if not args:
@@ -47,7 +47,8 @@ def main():
     # alert - setting globals here
     Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
     Globals.try_patchmatch = args.patchmatch
+    Globals.always_use_cpu = args.always_use_cpu
     print(f'>> InvokeAI runtime directory is "{Globals.root}"')

     # loading here to avoid long delays on startup
@@ -339,8 +340,8 @@ def main_loop(gen, opt):
                     filename,
                     tool,
                     formatted_dream_prompt,
                 )

             if (not postprocessed) or opt.save_original:
                 # only append to results if we didn't overwrite an earlier output
                 results.append([path, formatted_dream_prompt])
@@ -430,7 +431,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         add_embedding_terms(gen, completer)
         completer.add_history(command)
         operation = None

     elif command.startswith('!models'):
         gen.model_cache.print_models()
         completer.add_history(command)
@@ -531,7 +532,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
     completer.complete_extensions(('.yaml','.yml'))
     completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'

     done = False
     while not done:
         new_config['config'] = input('Configuration file for this model: ')
@@ -562,7 +563,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
             print('** Please enter a valid integer between 64 and 2048')

     make_default = input('Make this the default model? [n] ') in ('y','Y')

     if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
         completer.add_model(model_name)
@@ -575,14 +576,14 @@ def del_config(model_name:str, gen, opt, completer):
     gen.model_cache.commit(opt.conf)
     print(f'** {model_name} deleted')
     completer.del_model(model_name)

 def edit_config(model_name:str, gen, opt, completer):
     config = gen.model_cache.config

     if model_name not in config:
         print(f'** Unknown model {model_name}')
         return

     print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
     conf = config[model_name]
@@ -595,10 +596,10 @@ def edit_config(model_name:str, gen, opt, completer):
     make_default = input('Make this the default model? [n] ') in ('y','Y')
     completer.complete_extensions(None)
     write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)

 def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
     current_model = gen.model_name
     op = 'modify' if clobber else 'import'
     print('\n>> New configuration:')
     if make_default:
@@ -621,7 +622,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
         gen.model_cache.set_default_model(model_name)
     gen.model_cache.commit(conf_path)

     do_switch = input(f'Keep model loaded? [y]')
     if len(do_switch)==0 or do_switch[0] in ('y','Y'):
         pass
@@ -651,7 +652,7 @@ def do_postprocess (gen, opt, callback):
         opt.prompt = opt.new_prompt
     else:
         opt.prompt = None

     if os.path.dirname(file_path) == '': #basename given
         file_path = os.path.join(opt.outdir,file_path)
@@ -716,7 +717,7 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
     )
     meta['image']['postprocessing'] = pp
     write_metadata(new_file,meta)

 def prepare_image_metadata(
         opt,
         prefix,
@@ -794,21 +795,21 @@ def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
     os.chdir(
         os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
     )

     invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
     try:
         invoke_ai_web_server.run()
     except KeyboardInterrupt:
         pass

 def add_embedding_terms(gen,completer):
     '''
     Called after setting the model, updates the autocompleter with
     any terms loaded by the embedding manager.
     '''
     completer.add_embedding_terms(gen.model.embedding_manager.list_terms())

 def split_variations(variations_string) -> list:
     # shotgun parsing, woo
     parts = []
@@ -865,7 +866,7 @@ def make_step_callback(gen, opt, prefix):
             image = gen.sample_to_image(img)
             image.save(filename,'PNG')
     return callback

 def retrieve_dream_command(opt,command,completer):
     '''
     Given a full or partial path to a previously-generated image file,
@@ -873,7 +874,7 @@ def retrieve_dream_command(opt,command,completer):
     and pop it into the readline buffer (linux, Mac), or print out a comment
     for cut-and-paste (windows)

     Given a wildcard path to a folder with image png files,
     will retrieve and format the dream command used to generate the images,
     and save them to a file commands.txt for further processing
     '''
@@ -909,7 +910,7 @@ def write_commands(opt, file_path:str, outfilepath:str):
         except ValueError:
             print(f'## "{basename}": unacceptable pattern')
             return

     commands = []
     cmd = None
     for path in paths:
@@ -938,7 +939,7 @@ def emergency_model_reconfigure():
     print('    After reconfiguration is done, please relaunch invoke.py.    ')
     print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
     print('configure_invokeai is launching....\n')

     sys.argv = ['configure_invokeai','--interactive']
     import configure_invokeai
     configure_invokeai.main()
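Taken together, the CLI changes amount to a single new line of wiring: the parsed flag is copied into the process-wide Globals namespace right after argument parsing, before any model is loaded. A minimal sketch of that startup sequence, with Globals imported from ldm.invoke.globals (confirmed by the device-file hunk below) and an assumed ldm.invoke.args module path for the Args class:

    import os
    from ldm.invoke.globals import Globals
    from ldm.invoke.args import Args  # assumed module path for the Args class shown below

    # Parse the command line, then publish the results process-wide so that
    # choose_torch_device() already sees the flag on its first call.
    opt = Args()
    args = opt.parse_args()
    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
    Globals.try_patchmatch = args.patchmatch
    Globals.always_use_cpu = args.always_use_cpu  # the new flag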

View File

@@ -337,7 +337,7 @@ class Args(object):
         if not hasattr(cmd_switches,name) and not hasattr(arg_switches,name):
             raise AttributeError

         value_arg,value_cmd = (None,None)
         try:
             value_cmd = getattr(cmd_switches,name)
@@ -393,7 +393,7 @@ class Args(object):
             description=
             """
             Generate images using Stable Diffusion.
             Use --web to launch the web interface.
             Use --from_file to load prompts from a file path or standard input ("-").
             Otherwise you will be dropped into an interactive command prompt (type -h for help.)
             Other command-line arguments are defaults that can usually be overridden
@@ -455,6 +455,12 @@ class Args(object):
             action='store_true',
             help='Force free gpu memory before final decoding',
         )
+        model_group.add_argument(
+            "--always_use_cpu",
+            dest="always_use_cpu",
+            action="store_true",
+            help="Force use of CPU even if GPU is available"
+        )
         model_group.add_argument(
             '--precision',
             dest='precision',
@@ -1036,7 +1042,7 @@ def metadata_dumps(opt,
     Given an Args object, returns a dict containing the keys and
     structure of the proposed stable diffusion metadata standard
     https://github.com/lstein/stable-diffusion/discussions/392

     This is intended to be turned into JSON and stored in the
     "sd
     '''
@@ -1119,7 +1125,7 @@ def args_from_png(png_file_path) -> list[Args]:
         meta = ldm.invoke.pngwriter.retrieve_metadata(png_file_path)
     except AttributeError:
         return [legacy_metadata_load({},png_file_path)]

     try:
         return metadata_loads(meta)
     except:
@@ -1218,4 +1224,4 @@ def legacy_metadata_load(meta,pathname) -> Args:
         opt.prompt = ''
         opt.seed = 0
     return opt
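Because the argument is registered with action="store_true", it defaults to False when omitted, matching the Globals.always_use_cpu = False default added further down. From the command line the flag would be used like this (the scripts/invoke.py entry-point name is an assumption; it does not appear on this page):

    # force CPU inference even on a machine with a working CUDA or MPS backend
    python scripts/invoke.py --always_use_cpu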

View File

@@ -1,9 +1,12 @@
 import torch
 from torch import autocast
 from contextlib import nullcontext
+from ldm.invoke.globals import Globals

 def choose_torch_device() -> str:
     '''Convenience routine for guessing which GPU device to run model on'''
+    if Globals.always_use_cpu:
+        return "cpu"
     if torch.cuda.is_available():
         return 'cuda'
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
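The hunk above is truncated; for context, here is a sketch of what the complete helper plausibly looks like after this change. The trailing 'mps' and fallback 'cpu' returns are assumptions inferred from the visible logic, not shown in the diff:

    import torch
    from ldm.invoke.globals import Globals

    def choose_torch_device() -> str:
        '''Convenience routine for guessing which GPU device to run model on'''
        if Globals.always_use_cpu:
            return "cpu"   # new escape hatch: skip all GPU/MPS probing
        if torch.cuda.is_available():
            return 'cuda'
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return 'mps'   # assumed continuation (truncated in the hunk above)
        return 'cpu'       # assumed final fallback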

View File

@@ -8,6 +8,7 @@ the attributes:
  - root           - the root directory under which "models" and "outputs" can be found
  - initfile       - path to the initialization file
  - try_patchmatch - option to globally disable loading of 'patchmatch' module
+ - always_use_cpu - force use of CPU even if GPU is available
 '''

 import os
@@ -24,3 +25,6 @@ Globals.initfile = os.path.expanduser('~/.invokeai')
 # Awkward workaround to disable attempted loading of pypatchmatch
 # which is causing CI tests to error out.
 Globals.try_patchmatch = True
+
+# Use CPU even if GPU is available (main use case is for debugging MPS issues)
+Globals.always_use_cpu = False
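A quick way to sanity-check the new path, assuming the device helper above lives at ldm.invoke.devices (the file names are not shown on this page):

    from ldm.invoke.globals import Globals
    from ldm.invoke.devices import choose_torch_device  # assumed module path

    Globals.always_use_cpu = True
    assert choose_torch_device() == "cpu"  # CUDA/MPS are never probed

Since choose_torch_device() consults the global on every call, the flag must be set before the first device lookup; that is why the CLI assigns it immediately after argument parsing.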