enable fast switching between models in invoke.py

- This PR enables two new commands in the invoke.py script

 !models         -- list the available models and their cache status
 !switch <model> -- switch to the indicated model

Example:

 invoke> !models
   laion400m            not loaded  Latent Diffusion LAION400M model
   stable-diffusion-1.4     active  Stable Diffusion inference model version 1.4
   waifu-1.3                cached  Waifu anime model version 1.3
 invoke> !switch waifu-1.3
   >> Caching model stable-diffusion-1.4 in system RAM
   >> Retrieving model waifu-1.3 from system RAM cache

More details:

- Use fast switching algorithm described in PR #948
- Models are selected using their configuration stanza name
  given in models.yaml.
- To avoid filling up CPU RAM with cached models, this PR
  implements an LRU cache that monitors available CPU RAM.
- The caching code allows the minimum value of available RAM
  to be adjusted, but invoke.py does not currently have a
  command-line argument that allows you to set it. The
  minimum free RAM is arbitrarily set to 2 GB.
- Add optional description field to configs/models.yaml

Unrelated fixes:
- Added ">>" to the CompVis model loading messages to make the user experience
  more consistent.
- When an input image is larger than the model's default dimensions, the warning
  about possibly running out of VRAM is now printed only the first time.
- Fixed bug that was causing help message to be printed twice. This involved
  moving the import line for the web backend into the section where it is
  called.
This commit is contained in:
Lincoln Stein 2022-10-12 02:14:59 -04:00
parent b9e910b5f4
commit 19341e95a6
7 changed files with 219 additions and 191 deletions

View File

@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.model_cache import ModelCache
"""Simplified text to image API for stable diffusion/latent diffusion """Simplified text to image API for stable diffusion/latent diffusion
@ -123,12 +124,11 @@ class Generate:
esrgan=None, esrgan=None,
free_gpu_mem=False, free_gpu_mem=False,
): ):
models = OmegaConf.load(conf) mconfig = OmegaConf.load(conf)
mconfig = models[model] self.model_name = model
self.weights = mconfig.weights if weights is None else weights self.height = None
self.config = mconfig.config if config is None else config self.width = None
self.height = mconfig.height self.model_cache = None
self.width = mconfig.width
self.iterations = 1 self.iterations = 1
self.steps = 50 self.steps = 50
self.cfg_scale = 7.5 self.cfg_scale = 7.5
@ -139,6 +139,7 @@ class Generate:
self.seamless = False self.seamless = False
self.embedding_path = embedding_path self.embedding_path = embedding_path
self.model = None # empty for now self.model = None # empty for now
self.model_hash = None
self.sampler = None self.sampler = None
self.device = None self.device = None
self.session_peakmem = None self.session_peakmem = None
@ -149,6 +150,7 @@ class Generate:
self.codeformer = codeformer self.codeformer = codeformer
self.esrgan = esrgan self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
# Note that in previous versions, there was an option to pass the # Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so # device to Generate(). However the device was then ignored, so
@ -164,6 +166,9 @@ class Generate:
if self.precision == 'auto': if self.precision == 'auto':
self.precision = choose_precision(self.device) self.precision = choose_precision(self.device)
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
# for VRAM usage statistics # for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -294,7 +299,12 @@ class Generate:
with_variations = [] if with_variations is None else with_variations with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache # will instantiate the model or return it from cache
model = self.load_model() model = self.set_model(self.model_name)
# self.width and self.height are set by set_model()
# to the width and height of the image training set
width = width or self.width
height = height or self.height
for m in model.modules(): for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
@ -584,8 +594,9 @@ class Generate:
# this returns a torch tensor # this returns a torch tensor
init_mask = self._create_init_mask(image, width, height, fit=fit) init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height): if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
@ -635,29 +646,47 @@ class Generate:
return self.generators['inpaint'] return self.generators['inpaint']
def load_model(self): def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time""" '''
if self.model is None: preload model identified in self.model_name
seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) '''
try: self.set_model(self.model_name)
model = self._load_model_from_config(self.config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
except AttributeError as e:
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
raise SystemExit from e
self._set_sampler() def set_model(self,model_name):
"""
Given the name of a model defined in models.yaml, will load and initialize it
and return the model object. Previously-used models will be cached.
"""
if self.model_name == model_name and self.model is not None:
return self.model
for m in self.model.modules(): model_data = self.model_cache.get_model(model_name)
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if model_data is None or len(model_data) == 0:
m._orig_padding_mode = m.padding_mode print(f'** Model switch failed **')
return self.model
self.model = model_data['model']
self.width = model_data['width']
self.height= model_data['height']
self.model_hash = model_data['hash']
# uncache generators so they pick up new models
self.generators = {}
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
model.embedding_manager.load(
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
self._set_sampler()
for m in self.model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m._orig_padding_mode = m.padding_mode
self.model_name = model_name
return self.model return self.model
def correct_colors(self, def correct_colors(self,
@ -761,53 +790,6 @@ class Generate:
print(msg) print(msg)
# Be warned: config is the path to the model config file, not the invoke conf file!
# Also note that we can get config and weights from self, so why do we need to
# pass them as args?
def _load_model_from_config(self, config, weights):
print(f'>> Loading model from {weights}')
# for usage statistics
device_type = choose_torch_device()
if device_type == 'cuda':
torch.cuda.reset_peak_memory_stats()
tic = time.time()
# this does the work
c = OmegaConf.load(config)
with open(weights,'rb') as f:
weight_bytes = f.read()
self.model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
if self.precision == 'float16':
print('>> Using faster float16 precision')
model.to(torch.float16)
else:
print('>> Using more accurate float32 precision')
model.to(self.device)
model.eval()
# usage statistics
toc = time.time()
print(
f'>> Model loaded in', '%4.2fs' % (toc - tic)
)
if self._has_cuda():
print(
'>> Max VRAM used to load the model:',
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
'\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
)
return model
def _load_img(self, img, width, height)->Image: def _load_img(self, img, width, height)->Image:
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
image = img image = img
@ -951,26 +933,6 @@ class Generate:
def _has_cuda(self): def _has_cuda(self):
return self.device.type == 'cuda' return self.device.type == 'cuda'
def _cached_sha256(self,path,data):
dirname = os.path.dirname(path)
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f:
hash = f.read()
return hash
print(f'>> Calculating sha256 hash of weights file')
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f:
f.write(hash)
return hash
def write_intermediate_images(self,modulus,path): def write_intermediate_images(self,modulus,path):
counter = -1 counter = -1
if not os.path.exists(path): if not os.path.exists(path):

View File

@ -20,22 +20,35 @@ from ldm.util import instantiate_from_config
GIGS=2**30 GIGS=2**30
AVG_MODEL_SIZE=2.1*GIGS AVG_MODEL_SIZE=2.1*GIGS
DEFAULT_MIN_AVAIL=2*GIGS
class ModelCache(object): class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_free_mem=2*GIGS): def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
'''
Initialize with the loaded models.yaml configuration (an OmegaConf
object), the torch device type, and precision. The optional
min_avail_mem argument specifies how much unused system
(CPU) memory to preserve. The cache of models in RAM will
grow until this value is approached. Default is 2G.
'''
# prevent nasty-looking CLIP log message # prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
self.config = config self.config = config
self.precision = precision self.precision = precision
self.device = torch.device(device_type) self.device = torch.device(device_type)
self.min_free_mem = min_free_mem self.min_avail_mem = min_avail_mem
self.models = {} self.models = {}
self.stack = [] # this is an LRU FIFO self.stack = [] # this is an LRU FIFO
self.current_model = None self.current_model = None
def get_model(self, model_name:str): def get_model(self, model_name:str):
'''
Given the name of a model defined in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
'''
if model_name not in self.config: if model_name not in self.config:
print(f'"{model_name}" is not a known model name. Please check your models.yaml file') print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return None return None
if self.current_model != model_name: if self.current_model != model_name:
@ -43,22 +56,42 @@ class ModelCache(object):
if model_name in self.models: if model_name in self.models:
requested_model = self.models[model_name]['model'] requested_model = self.models[model_name]['model']
self._model_from_cpu(requested_model) print(f'>> Retrieving model {model_name} from system RAM cache')
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
width = self.models[model_name]['width'] width = self.models[model_name]['width']
height = self.models[model_name]['height'] height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else: else:
self._check_memory() self._check_memory()
requested_model, width, height = self._load_model(model_name) try:
self.models[model_name] = {} requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name]['model'] = requested_model self.models[model_name] = {}
self.models[model_name]['width'] = width self.models[model_name]['model'] = requested_model
self.models[model_name]['height'] = height self.models[model_name]['width'] = width
self.models[model_name]['height'] = height
self.models[model_name]['hash'] = hash
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
return {}
self.current_model = model_name self.current_model = model_name
self._push_newest_model(model_name) self._push_newest_model(model_name)
return requested_model, width, height return {
'model':requested_model,
'width':width,
'height':height,
'hash': hash
}
def list_models(self): def list_models(self) -> dict:
'''
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
},
model_name2: { etc }
'''
result = {}
for name in self.config: for name in self.config:
try: try:
description = self.config[name].description description = self.config[name].description
@ -70,28 +103,26 @@ class ModelCache(object):
status = 'cached' status = 'cached'
else: else:
status = 'not loaded' status = 'not loaded'
print(f'{name:20s} {status:>10s} {description}') result[name]={}
result[name]['status']=status
result[name]['description']=description
return result
def print_models(self):
'''
Print a table of models, their descriptions, and load status
'''
models = self.list_models()
for name in models:
print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
def _check_memory(self): def _check_memory(self):
free_memory = psutil.virtual_memory()[4] avail_memory = psutil.virtual_memory()[1]
print(f'DEBUG: free memory = {free_memory}, min_mem = {self.min_free_mem}') if avail_memory + AVG_MODEL_SIZE < self.min_avail_mem:
while free_memory + AVG_MODEL_SIZE < self.min_free_mem:
print(f'DEBUG: free memory = {free_memory}')
least_recent_model = self._pop_oldest_model() least_recent_model = self._pop_oldest_model()
if least_recent_model is None: if least_recent_model is not None:
return del self.models[least_recent_model]
gc.collect()
print(f'DEBUG: clearing {least_recent_model} from cache (refcount = {getrefcount(self.models[least_recent_model]["model"])})')
del self.models[least_recent_model]['model']
gc.collect()
new_free_memory = psutil.virtual_memory()[4]
if new_free_memory <= free_memory:
print(f'>> **Unable to free memory for model caching.**')
break;
free_memory = new_free_memory
def _load_model(self, model_name:str): def _load_model(self, model_name:str):
@ -106,18 +137,20 @@ class ModelCache(object):
width = mconfig.width width = mconfig.width
height = mconfig.height height = mconfig.height
print(f'>> Loading {model_name} weights from {weights}') print(f'>> Loading {model_name} from {weights}')
# for usage statistics # for usage statistics
if self._has_cuda(): if self._has_cuda():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
tic = time.time() tic = time.time()
# this does the work # this does the work
c = OmegaConf.load(config) c = OmegaConf.load(config)
with open(weights,'rb') as f: with open(weights,'rb') as f:
weight_bytes = f.read() weight_bytes = f.read()
self.model_hash = self._cached_sha256(weights,weight_bytes) model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu') pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes del weight_bytes
sd = pl_sd['state_dict'] sd = pl_sd['state_dict']
@ -143,39 +176,40 @@ class ModelCache(object):
'\n>> Current VRAM usage:' '\n>> Current VRAM usage:'
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9), '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
) )
return model, width, height return model, width, height, model_hash
def unload_model(self, model_name:str): def unload_model(self, model_name:str):
if model_name not in self.models: if model_name not in self.models:
return return
print(f'>> Unloading model {model_name}') print(f'>> Caching model {model_name} in system RAM')
model = self.models[model_name]['model'] model = self.models[model_name]['model']
self._model_to_cpu(model) self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect() gc.collect()
if self._has_cuda(): if self._has_cuda():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def _model_to_cpu(self,model): def _model_to_cpu(self,model):
if self._has_cuda(): if self._has_cuda():
print(f'DEBUG: moving model to cpu')
model.first_stage_model.to('cpu') model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu') model.cond_stage_model.to('cpu')
model.model.to('cpu') model.model.to('cpu')
return model.to('cpu')
def _model_from_cpu(self,model): def _model_from_cpu(self,model):
if self._has_cuda(): if self._has_cuda():
print(f'DEBUG: moving model into {self.device.type}')
model.to(self.device) model.to(self.device)
model.first_stage_model.to(self.device) model.first_stage_model.to(self.device)
model.cond_stage_model.to(self.device) model.cond_stage_model.to(self.device)
return model
def _pop_oldest_model(self): def _pop_oldest_model(self):
''' '''
Remove the first element of the FIFO, which ought Remove the first element of the FIFO, which ought
to be the least recently accessed model. to be the least recently accessed model. Do not
pop the last one, because it is in active use!
''' '''
if len(self.stack)>0: if len(self.stack) > 1:
self.stack.pop(0) return self.stack.pop(0)
def _push_newest_model(self,model_name:str): def _push_newest_model(self,model_name:str):
''' '''
@ -187,7 +221,6 @@ class ModelCache(object):
except ValueError: except ValueError:
pass pass
self.stack.append(model_name) self.stack.append(model_name)
print(f'DEBUG, stack={self.stack}')
def _has_cuda(self): def _has_cuda(self):
return self.device.type == 'cuda' return self.device.type == 'cuda'

View File

@ -47,7 +47,10 @@ COMMANDS = (
'--skip_normalize','-x', '--skip_normalize','-x',
'--log_tokenization','-t', '--log_tokenization','-t',
'--hires_fix', '--hires_fix',
'!fix','!fetch','!history','!search','!clear', '!fix','!fetch','!history','!search','!clear','!models','!switch',
)
MODEL_COMMANDS = (
'!switch',
) )
IMG_PATH_COMMANDS = ( IMG_PATH_COMMANDS = (
'--outdir[=\s]', '--outdir[=\s]',
@ -63,8 +66,9 @@ IMG_FILE_COMMANDS=(
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$' path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
class Completer(object): class Completer(object):
def __init__(self, options): def __init__(self, options, models=[]):
self.options = sorted(options) self.options = sorted(options)
self.models = sorted(models)
self.seeds = set() self.seeds = set()
self.matches = list() self.matches = list()
self.default_dir = None self.default_dir = None
@ -88,6 +92,9 @@ class Completer(object):
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer): elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state) self.matches= self._seed_completions(text,state)
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text,state)
# This is the first time for this text, so build a match list. # This is the first time for this text, so build a match list.
elif text: elif text:
self.matches = [ self.matches = [
@ -188,6 +195,21 @@ class Completer(object):
matches.sort() matches.sort()
return matches return matches
def _model_completions(self, text, state):
m = re.search('(!switch\s+)(\w*)',text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.models:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def _pre_input_hook(self): def _pre_input_hook(self):
if self.linebuffer: if self.linebuffer:
readline.insert_text(self.linebuffer) readline.insert_text(self.linebuffer)
@ -266,9 +288,9 @@ class DummyCompleter(Completer):
def set_line(self,line): def set_line(self,line):
print(f'# {line}') print(f'# {line}')
def get_completer(opt:Args)->Completer: def get_completer(opt:Args, models=[])->Completer:
if readline_available: if readline_available:
completer = Completer(COMMANDS) completer = Completer(COMMANDS,models)
readline.set_completer( readline.set_completer(
completer.complete completer.complete

View File

@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
], 'currently only supporting "eps" and "x0"' ], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization self.parameterization = parameterization
print( print(
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode' f' >> {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
) )
self.cond_stage_model = None self.cond_stage_model = None
self.clip_denoised = clip_denoised self.clip_denoised = clip_denoised

View File

@ -245,7 +245,7 @@ class AttnBlock(nn.Module):
def make_attn(in_channels, attn_type="vanilla"): def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels") print(f" >> Making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla": if attn_type == "vanilla":
return AttnBlock(in_channels) return AttnBlock(in_channels)
elif attn_type == "none": elif attn_type == "none":
@ -521,7 +521,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1] block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1) curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res) self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format( print(" >> Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in

View File

@ -75,7 +75,7 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
if verbose: if verbose:
print( print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.' f' >> {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
) )
return total_params return total_params

View File

@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log from ldm.invoke.log import write_log
from omegaconf import OmegaConf from omegaconf import OmegaConf
from backend.invoke_ai_web_server import InvokeAIWebServer
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
@ -33,7 +31,7 @@ def main():
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.') print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1) sys.exit(-1)
print('* Initializing, be patient...\n') print('* Initializing, be patient...')
from ldm.generate import Generate from ldm.generate import Generate
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
@ -42,45 +40,7 @@ def main():
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules # Loading Face Restoration and ESRGAN Modules
try: gfpgan,codeformer,esrgan = load_face_restoration(opt)
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
@ -100,6 +60,24 @@ def main():
print(f'{e}. Aborting.') print(f'{e}. Aborting.')
sys.exit(-1) sys.exit(-1)
# creating a Generate object:
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
if opt.seamless: if opt.seamless:
print(">> changed to seamless tiling mode") print(">> changed to seamless tiling mode")
@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
done = False done = False
path_filter = re.compile(r'[<>:"/\\|?*]') path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list() last_results = list()
model_config = OmegaConf.load(opt.conf)[opt.model] model_config = OmegaConf.load(opt.conf)
# The readline completer reads history from the .dream_history file located in the # The readline completer reads history from the .dream_history file located in the
# output directory specified at the time of script launch. We do not currently support # output directory specified at the time of script launch. We do not currently support
# changing the history file midstream when the output directory is changed. # changing the history file midstream when the output directory is changed.
completer = get_completer(opt) completer = get_completer(opt, models=list(model_config.keys()))
output_cntr = completer.get_current_history_length()+1 output_cntr = completer.get_current_history_length()+1
# os.pathconf is not available on Windows # os.pathconf is not available on Windows
@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
command = command.replace('!fix ','',1) command = command.replace('!fix ','',1)
operation = 'postprocess' operation = 'postprocess'
elif subcommand.startswith('switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
continue
elif subcommand.startswith('models'):
model_name = command.replace('!models ','',1)
gen.model_cache.print_models()
continue
elif subcommand.startswith('fetch'): elif subcommand.startswith('fetch'):
file_path = command.replace('!fetch ','',1) file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path,completer) retrieve_dream_command(opt,file_path,completer)
@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
# width and height are set by model if not specified # width and height are set by model if not specified
if not opt.width: if not opt.width:
opt.width = model_config.width opt.width = gen.width
if not opt.height: if not opt.height:
opt.height = model_config.height opt.height = gen.height
# retrieve previous value of init image if requested # retrieve previous value of init image if requested
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img): if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str: # command string
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan): def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...') print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory # Change working directory to the stable-diffusion directory
os.chdir( os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
else: else:
return parts return parts
def load_face_restoration(opt):
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def retrieve_dream_command(opt,file_path,completer): def retrieve_dream_command(opt,file_path,completer):
''' '''
Given a full or partial path to a previously-generated image file, Given a full or partial path to a previously-generated image file,