enable fast switching between models in invoke.py

- This PR enables two new commands in the invoke.py script

 !models         -- list the available models and their cache status
 !switch <model> -- switch to the indicated model

Example:

 invoke> !models
   laion400m            not loaded  Latent Diffusion LAION400M model
   stable-diffusion-1.4     active  Stable Diffusion inference model version 1.4
   waifu-1.3                cached  Waifu anime model version 1.3
 invoke> !switch waifu-1.3
   >> Caching model stable-diffusion-1.4 in system RAM
   >> Retrieving model waifu-1.3 from system RAM cache

The name and descriptions of the models are taken from
`config/models.yaml`. A future enhancement to `model_cache.py` will be
to enable new model stanzas to be added to the file
programmatically. This will be useful for the WebGUI.

More details:

- Use fast switching algorithm described in PR #948
- Models are selected using their configuration stanza name
  given in models.yaml.
- To avoid filling up CPU RAM with cached models, this PR
  implements an LRU cache that monitors available CPU RAM.
- The caching code allows the minimum value of available RAM
  to be adjusted, but invoke.py does not currently have a
  command-line argument that allows you to set it. The
  minimum free RAM is arbitrarily set to 2 GB.
- Add optional description field to configs/models.yaml

Unrelated fixes:
- Added ">>" to CompViz model loading messages in order to make user experience
  more consistent.
- When generating an image greater than defaults, will only warn about possible
  VRAM filling the first time.
- Fixed bug that was causing help message to be printed twice. This involved
  moving the import line for the web backend into the section where it is
  called.

Coauthored by: @ArDiouscuros
This commit is contained in:
Lincoln Stein 2022-10-12 02:14:59 -04:00
parent b9e910b5f4
commit 488334710b
7 changed files with 219 additions and 191 deletions

View File

@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.model_cache import ModelCache
"""Simplified text to image API for stable diffusion/latent diffusion
@ -123,12 +124,11 @@ class Generate:
esrgan=None,
free_gpu_mem=False,
):
models = OmegaConf.load(conf)
mconfig = models[model]
self.weights = mconfig.weights if weights is None else weights
self.config = mconfig.config if config is None else config
self.height = mconfig.height
self.width = mconfig.width
mconfig = OmegaConf.load(conf)
self.model_name = model
self.height = None
self.width = None
self.model_cache = None
self.iterations = 1
self.steps = 50
self.cfg_scale = 7.5
@ -139,6 +139,7 @@ class Generate:
self.seamless = False
self.embedding_path = embedding_path
self.model = None # empty for now
self.model_hash = None
self.sampler = None
self.device = None
self.session_peakmem = None
@ -149,6 +150,7 @@ class Generate:
self.codeformer = codeformer
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -164,6 +166,9 @@ class Generate:
if self.precision == 'auto':
self.precision = choose_precision(self.device)
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
transformers.logging.set_verbosity_error()
@ -294,7 +299,12 @@ class Generate:
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
model = self.load_model()
model = self.set_model(self.model_name)
# self.width and self.height are set by set_model()
# to the width and height of the image training set
width = width or self.width
height = height or self.height
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
@ -584,8 +594,9 @@ class Generate:
# this returns a torch tensor
init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height):
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
@ -635,29 +646,47 @@ class Generate:
return self.generators['inpaint']
def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
'''
preload model identified in self.model_name
'''
self.set_model(self.model_name)
def set_model(self,model_name):
"""
Given the name of a model defined in models.yaml, will load and initialize it
and return the model object. Previously-used models will be cached.
"""
if self.model_name == model_name and self.model is not None:
return self.model
model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0:
print(f'** Model switch failed **')
return self.model
self.model = model_data['model']
self.width = model_data['width']
self.height= model_data['height']
self.model_hash = model_data['hash']
# uncache generators so they pick up new models
self.generators = {}
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
try:
model = self._load_model_from_config(self.config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
except AttributeError as e:
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
raise SystemExit from e
self._set_sampler()
for m in self.model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
self.model_name = model_name
return self.model
def correct_colors(self,
@ -761,53 +790,6 @@ class Generate:
print(msg)
# Be warned: config is the path to the model config file, not the invoke conf file!
# Also note that we can get config and weights from self, so why do we need to
# pass them as args?
def _load_model_from_config(self, config, weights):
print(f'>> Loading model from {weights}')
# for usage statistics
device_type = choose_torch_device()
if device_type == 'cuda':
torch.cuda.reset_peak_memory_stats()
tic = time.time()
# this does the work
c = OmegaConf.load(config)
with open(weights,'rb') as f:
weight_bytes = f.read()
self.model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
if self.precision == 'float16':
print('>> Using faster float16 precision')
model.to(torch.float16)
else:
print('>> Using more accurate float32 precision')
model.to(self.device)
model.eval()
# usage statistics
toc = time.time()
print(
f'>> Model loaded in', '%4.2fs' % (toc - tic)
)
if self._has_cuda():
print(
'>> Max VRAM used to load the model:',
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
'\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)
return model
def _load_img(self, img, width, height)->Image:
if isinstance(img, Image.Image):
image = img
@ -951,26 +933,6 @@ class Generate:
def _has_cuda(self):
return self.device.type == 'cuda'
def _cached_sha256(self,path,data):
dirname = os.path.dirname(path)
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f:
hash = f.read()
return hash
print(f'>> Calculating sha256 hash of weights file')
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f:
f.write(hash)
return hash
def write_intermediate_images(self,modulus,path):
counter = -1
if not os.path.exists(path):

View File

@ -20,22 +20,35 @@ from ldm.util import instantiate_from_config
GIGS=2**30
AVG_MODEL_SIZE=2.1*GIGS
DEFAULT_MIN_AVAIL=2*GIGS
class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_free_mem=2*GIGS):
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
min_avail_mem argument specifies how much unused system
(CPU) memory to preserve. The cache of models in RAM will
grow until this value is approached. Default is 2G.
'''
# prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error()
self.config = config
self.precision = precision
self.device = torch.device(device_type)
self.min_free_mem = min_free_mem
self.min_avail_mem = min_avail_mem
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
def get_model(self, model_name:str):
'''
Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
'''
if model_name not in self.config:
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return None
if self.current_model != model_name:
@ -43,22 +56,42 @@ class ModelCache(object):
if model_name in self.models:
requested_model = self.models[model_name]['model']
self._model_from_cpu(requested_model)
print(f'>> Retrieving model {model_name} from system RAM cache')
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else:
self._check_memory()
requested_model, width, height = self._load_model(model_name)
try:
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {}
self.models[model_name]['model'] = requested_model
self.models[model_name]['width'] = width
self.models[model_name]['height'] = height
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
return {}
self.current_model = model_name
self._push_newest_model(model_name)
return requested_model, width, height
return {
'model':requested_model,
'width':width,
'height':height,
'hash': hash
}
def list_models(self):
def list_models(self) -> dict:
'''
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
},
model_name2: { etc }
'''
result = {}
for name in self.config:
try:
description = self.config[name].description
@ -70,29 +103,27 @@ class ModelCache(object):
status = 'cached'
else:
status = 'not loaded'
print(f'{name:20s} {status:>10s} {description}')
result[name]={}
result[name]['status']=status
result[name]['description']=description
return result
def print_models(self):
'''
Print a table of models, their descriptions, and load status
'''
models = self.list_models()
for name in models:
print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
def _check_memory(self):
free_memory = psutil.virtual_memory()[4]
print(f'DEBUG: free memory = {free_memory}, min_mem = {self.min_free_mem}')
while free_memory + AVG_MODEL_SIZE < self.min_free_mem:
print(f'DEBUG: free memory = {free_memory}')
avail_memory = psutil.virtual_memory()[1]
if avail_memory + AVG_MODEL_SIZE < self.min_avail_mem:
least_recent_model = self._pop_oldest_model()
if least_recent_model is None:
return
print(f'DEBUG: clearing {least_recent_model} from cache (refcount = {getrefcount(self.models[least_recent_model]["model"])})')
del self.models[least_recent_model]['model']
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
new_free_memory = psutil.virtual_memory()[4]
if new_free_memory <= free_memory:
print(f'>> **Unable to free memory for model caching.**')
break;
free_memory = new_free_memory
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
@ -106,18 +137,20 @@ class ModelCache(object):
width = mconfig.width
height = mconfig.height
print(f'>> Loading {model_name} weights from {weights}')
print(f'>> Loading {model_name} from {weights}')
# for usage statistics
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
tic = time.time()
# this does the work
c = OmegaConf.load(config)
with open(weights,'rb') as f:
weight_bytes = f.read()
self.model_hash = self._cached_sha256(weights,weight_bytes)
model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
@ -143,39 +176,40 @@ class ModelCache(object):
'\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)
return model, width, height
return model, width, height, model_hash
def unload_model(self, model_name:str):
if model_name not in self.models:
return
print(f'>> Unloading model {model_name}')
print(f'>> Caching model {model_name} in system RAM')
model = self.models[model_name]['model']
self._model_to_cpu(model)
self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
def _model_to_cpu(self,model):
if self._has_cuda():
print(f'DEBUG: moving model to cpu')
model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu')
model.model.to('cpu')
return model.to('cpu')
def _model_from_cpu(self,model):
if self._has_cuda():
print(f'DEBUG: moving model into {self.device.type}')
model.to(self.device)
model.first_stage_model.to(self.device)
model.cond_stage_model.to(self.device)
return model
def _pop_oldest_model(self):
'''
Remove the first element of the FIFO, which ought
to be the least recently accessed model.
to be the least recently accessed model. Do not
pop the last one, because it is in active use!
'''
if len(self.stack)>0:
self.stack.pop(0)
if len(self.stack) > 1:
return self.stack.pop(0)
def _push_newest_model(self,model_name:str):
'''
@ -187,7 +221,6 @@ class ModelCache(object):
except ValueError:
pass
self.stack.append(model_name)
print(f'DEBUG, stack={self.stack}')
def _has_cuda(self):
return self.device.type == 'cuda'

View File

@ -47,7 +47,10 @@ COMMANDS = (
'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'!fix','!fetch','!history','!search','!clear',
'!fix','!fetch','!history','!search','!clear','!models','!switch',
)
MODEL_COMMANDS = (
'!switch',
)
IMG_PATH_COMMANDS = (
'--outdir[=\s]',
@ -63,8 +66,9 @@ IMG_FILE_COMMANDS=(
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
class Completer(object):
def __init__(self, options):
def __init__(self, options, models=[]):
self.options = sorted(options)
self.models = sorted(models)
self.seeds = set()
self.matches = list()
self.default_dir = None
@ -88,6 +92,9 @@ class Completer(object):
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text,state)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [
@ -188,6 +195,21 @@ class Completer(object):
matches.sort()
return matches
def _model_completions(self, text, state):
m = re.search('(!switch\s+)(\w*)',text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.models:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
@ -266,9 +288,9 @@ class DummyCompleter(Completer):
def set_line(self,line):
print(f'# {line}')
def get_completer(opt:Args)->Completer:
def get_completer(opt:Args, models=[])->Completer:
if readline_available:
completer = Completer(COMMANDS)
completer = Completer(COMMANDS,models)
readline.set_completer(
completer.complete

View File

@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
f' >> {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
)
self.cond_stage_model = None
self.clip_denoised = clip_denoised

View File

@ -245,7 +245,7 @@ class AttnBlock(nn.Module):
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
print(f" >> Making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
@ -521,7 +521,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
print(" >> Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in

View File

@ -75,7 +75,7 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
f' >> {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params

View File

@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from backend.invoke_ai_web_server import InvokeAIWebServer
def main():
"""Initialize command-line parsers and the diffusion model"""
@ -33,7 +31,7 @@ def main():
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
print('* Initializing, be patient...\n')
print('* Initializing, be patient...')
from ldm.generate import Generate
# these two lines prevent a horrible warning message from appearing
@ -42,45 +40,7 @@ def main():
transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
gfpgan,codeformer,esrgan = load_face_restoration(opt)
# make sure the output directory exists
if not os.path.exists(opt.outdir):
@ -100,6 +60,24 @@ def main():
print(f'{e}. Aborting.')
sys.exit(-1)
# creating a Generate object:
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
if opt.seamless:
print(">> changed to seamless tiling mode")
@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list()
model_config = OmegaConf.load(opt.conf)[opt.model]
model_config = OmegaConf.load(opt.conf)
# The readline completer reads history from the .dream_history file located in the
# output directory specified at the time of script launch. We do not currently support
# changing the history file midstream when the output directory is changed.
completer = get_completer(opt)
completer = get_completer(opt, models=list(model_config.keys()))
output_cntr = completer.get_current_history_length()+1
# os.pathconf is not available on Windows
@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
command = command.replace('!fix ','',1)
operation = 'postprocess'
elif subcommand.startswith('switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
continue
elif subcommand.startswith('models'):
model_name = command.replace('!models ','',1)
gen.model_cache.print_models()
continue
elif subcommand.startswith('fetch'):
file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path,completer)
@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
# width and height are set by model if not specified
if not opt.width:
opt.width = model_config.width
opt.width = gen.width
if not opt.height:
opt.height = model_config.height
opt.height = gen.height
# retrieve previous value of init image if requested
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str: # command string
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
else:
return parts
def load_face_restoration(opt):
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def retrieve_dream_command(opt,file_path,completer):
'''
Given a full or partial path to a previously-generated image file,