mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enable fast switching between models in invoke.py
- This PR enables two new commands in the invoke.py script !models -- list the available models and their cache status !switch <model> -- switch to the indicated model Example: invoke> !models laion400m not loaded Latent Diffusion LAION400M model stable-diffusion-1.4 active Stable Diffusion inference model version 1.4 waifu-1.3 cached Waifu anime model version 1.3 invoke> !switch waifu-1.3 >> Caching model stable-diffusion-1.4 in system RAM >> Retrieving model waifu-1.3 from system RAM cache The name and descriptions of the models are taken from `config/models.yaml`. A future enhancement to `model_cache.py` will be to enable new model stanzas to be added to the file programmatically. This will be useful for the WebGUI. More details: - Use fast switching algorithm described in PR #948 - Models are selected using their configuration stanza name given in models.yaml. - To avoid filling up CPU RAM with cached models, this PR implements an LRU cache that monitors available CPU RAM. - The caching code allows the minimum value of available RAM to be adjusted, but invoke.py does not currently have a command-line argument that allows you to set it. The minimum free RAM is arbitrarily set to 2 GB. - Add optional description field to configs/models.yaml Unrelated fixes: - Added ">>" to CompViz model loading messages in order to make user experience more consistent. - When generating an image greater than defaults, will only warn about possible VRAM filling the first time. - Fixed bug that was causing help message to be printed twice. This involved moving the import line for the web backend into the section where it is called. Coauthored by: @ArDiouscuros
This commit is contained in:
@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from omegaconf import OmegaConf
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
|
||||
|
||||
def main():
|
||||
"""Initialize command-line parsers and the diffusion model"""
|
||||
@ -33,7 +31,7 @@ def main():
|
||||
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
||||
sys.exit(-1)
|
||||
|
||||
print('* Initializing, be patient...\n')
|
||||
print('* Initializing, be patient...')
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
@ -42,45 +40,7 @@ def main():
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if opt.restore or opt.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
|
||||
# creating a simple text2image object with a handful of
|
||||
# defaults passed on the command line.
|
||||
# additional parameters will be added (or overriden) during
|
||||
# the user input loop
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
gfpgan,codeformer,esrgan = load_face_restoration(opt)
|
||||
|
||||
# make sure the output directory exists
|
||||
if not os.path.exists(opt.outdir):
|
||||
@ -100,6 +60,24 @@ def main():
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
if opt.seamless:
|
||||
print(">> changed to seamless tiling mode")
|
||||
|
||||
@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
|
||||
done = False
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
last_results = list()
|
||||
model_config = OmegaConf.load(opt.conf)[opt.model]
|
||||
model_config = OmegaConf.load(opt.conf)
|
||||
|
||||
# The readline completer reads history from the .dream_history file located in the
|
||||
# output directory specified at the time of script launch. We do not currently support
|
||||
# changing the history file midstream when the output directory is changed.
|
||||
completer = get_completer(opt)
|
||||
completer = get_completer(opt, models=list(model_config.keys()))
|
||||
output_cntr = completer.get_current_history_length()+1
|
||||
|
||||
# os.pathconf is not available on Windows
|
||||
@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
|
||||
command = command.replace('!fix ','',1)
|
||||
operation = 'postprocess'
|
||||
|
||||
elif subcommand.startswith('switch'):
|
||||
model_name = command.replace('!switch ','',1)
|
||||
gen.set_model(model_name)
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('models'):
|
||||
model_name = command.replace('!models ','',1)
|
||||
gen.model_cache.print_models()
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('fetch'):
|
||||
file_path = command.replace('!fetch ','',1)
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
|
||||
|
||||
# width and height are set by model if not specified
|
||||
if not opt.width:
|
||||
opt.width = model_config.width
|
||||
opt.width = gen.width
|
||||
if not opt.height:
|
||||
opt.height = model_config.height
|
||||
opt.height = gen.height
|
||||
|
||||
# retrieve previous value of init image if requested
|
||||
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
|
||||
@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str: # command string
|
||||
|
||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
|
||||
else:
|
||||
return parts
|
||||
|
||||
def load_face_restoration(opt):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if opt.restore or opt.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
return gfpgan,codeformer,esrgan
|
||||
|
||||
|
||||
def retrieve_dream_command(opt,file_path,completer):
|
||||
'''
|
||||
Given a full or partial path to a previously-generated image file,
|
||||
|
Reference in New Issue
Block a user