mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enable fast switching between models in invoke.py
- This PR enables two new commands in the invoke.py script !models -- list the available models and their cache status !switch <model> -- switch to the indicated model Example: invoke> !models laion400m not loaded Latent Diffusion LAION400M model stable-diffusion-1.4 active Stable Diffusion inference model version 1.4 waifu-1.3 cached Waifu anime model version 1.3 invoke> !switch waifu-1.3 >> Caching model stable-diffusion-1.4 in system RAM >> Retrieving model waifu-1.3 from system RAM cache The name and descriptions of the models are taken from `config/models.yaml`. A future enhancement to `model_cache.py` will be to enable new model stanzas to be added to the file programmatically. This will be useful for the WebGUI. More details: - Use fast switching algorithm described in PR #948 - Models are selected using their configuration stanza name given in models.yaml. - To avoid filling up CPU RAM with cached models, this PR implements an LRU cache that monitors available CPU RAM. - The caching code allows the minimum value of available RAM to be adjusted, but invoke.py does not currently have a command-line argument that allows you to set it. The minimum free RAM is arbitrarily set to 2 GB. - Add optional description field to configs/models.yaml Unrelated fixes: - Added ">>" to CompViz model loading messages in order to make user experience more consistent. - When generating an image greater than defaults, will only warn about possible VRAM filling the first time. - Fixed bug that was causing help message to be printed twice. This involved moving the import line for the web backend into the section where it is called. Coauthored by: @ArDiouscuros
This commit is contained in:
parent
b9e910b5f4
commit
488334710b
152
ldm/generate.py
152
ldm/generate.py
@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
@ -123,12 +124,11 @@ class Generate:
|
||||
esrgan=None,
|
||||
free_gpu_mem=False,
|
||||
):
|
||||
models = OmegaConf.load(conf)
|
||||
mconfig = models[model]
|
||||
self.weights = mconfig.weights if weights is None else weights
|
||||
self.config = mconfig.config if config is None else config
|
||||
self.height = mconfig.height
|
||||
self.width = mconfig.width
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.model_name = model
|
||||
self.height = None
|
||||
self.width = None
|
||||
self.model_cache = None
|
||||
self.iterations = 1
|
||||
self.steps = 50
|
||||
self.cfg_scale = 7.5
|
||||
@ -139,6 +139,7 @@ class Generate:
|
||||
self.seamless = False
|
||||
self.embedding_path = embedding_path
|
||||
self.model = None # empty for now
|
||||
self.model_hash = None
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.session_peakmem = None
|
||||
@ -149,6 +150,7 @@ class Generate:
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -164,6 +166,9 @@ class Generate:
|
||||
if self.precision == 'auto':
|
||||
self.precision = choose_precision(self.device)
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -294,7 +299,12 @@ class Generate:
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
model = self.load_model()
|
||||
model = self.set_model(self.model_name)
|
||||
|
||||
# self.width and self.height are set by set_model()
|
||||
# to the width and height of the image training set
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
@ -584,8 +594,9 @@ class Generate:
|
||||
# this returns a torch tensor
|
||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||
|
||||
if (image.width * image.height) > (self.width * self.height):
|
||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
self.size_matters = False
|
||||
|
||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||
|
||||
@ -635,29 +646,47 @@ class Generate:
|
||||
return self.generators['inpaint']
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
try:
|
||||
model = self._load_model_from_config(self.config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
self.model = model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
except AttributeError as e:
|
||||
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
raise SystemExit from e
|
||||
'''
|
||||
preload model identified in self.model_name
|
||||
'''
|
||||
self.set_model(self.model_name)
|
||||
|
||||
self._set_sampler()
|
||||
def set_model(self,model_name):
|
||||
"""
|
||||
Given the name of a model defined in models.yaml, will load and initialize it
|
||||
and return the model object. Previously-used models will be cached.
|
||||
"""
|
||||
if self.model_name == model_name and self.model is not None:
|
||||
return self.model
|
||||
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
model_data = self.model_cache.get_model(model_name)
|
||||
if model_data is None or len(model_data) == 0:
|
||||
print(f'** Model switch failed **')
|
||||
return self.model
|
||||
|
||||
self.model = model_data['model']
|
||||
self.width = model_data['width']
|
||||
self.height= model_data['height']
|
||||
self.model_hash = model_data['hash']
|
||||
|
||||
# uncache generators so they pick up new models
|
||||
self.generators = {}
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||
)
|
||||
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
self._set_sampler()
|
||||
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
|
||||
self.model_name = model_name
|
||||
return self.model
|
||||
|
||||
def correct_colors(self,
|
||||
@ -761,53 +790,6 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
# Be warned: config is the path to the model config file, not the invoke conf file!
|
||||
# Also note that we can get config and weights from self, so why do we need to
|
||||
# pass them as args?
|
||||
def _load_model_from_config(self, config, weights):
|
||||
print(f'>> Loading model from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
device_type = choose_torch_device()
|
||||
if device_type == 'cuda':
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.precision == 'float16':
|
||||
print('>> Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
else:
|
||||
print('>> Using more accurate float32 precision')
|
||||
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(
|
||||
f'>> Model loaded in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _load_img(self, img, width, height)->Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
@ -951,26 +933,6 @@ class Generate:
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def _cached_sha256(self,path,data):
|
||||
dirname = os.path.dirname(path)
|
||||
basename = os.path.basename(path)
|
||||
base, _ = os.path.splitext(basename)
|
||||
hashpath = os.path.join(dirname,base+'.sha256')
|
||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(f'>> Calculating sha256 hash of weights file')
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||
with open(hashpath,'w') as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
def write_intermediate_images(self,modulus,path):
|
||||
counter = -1
|
||||
if not os.path.exists(path):
|
||||
|
@ -20,22 +20,35 @@ from ldm.util import instantiate_from_config
|
||||
|
||||
GIGS=2**30
|
||||
AVG_MODEL_SIZE=2.1*GIGS
|
||||
DEFAULT_MIN_AVAIL=2*GIGS
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_free_mem=2*GIGS):
|
||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
|
||||
'''
|
||||
Initialize with the path to the models.yaml config file,
|
||||
the torch device type, and precision. The optional
|
||||
min_avail_mem argument specifies how much unused system
|
||||
(CPU) memory to preserve. The cache of models in RAM will
|
||||
grow until this value is approached. Default is 2G.
|
||||
'''
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.min_free_mem = min_free_mem
|
||||
self.min_avail_mem = min_avail_mem
|
||||
self.models = {}
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
|
||||
if self.current_model != model_name:
|
||||
@ -43,22 +56,42 @@ class ModelCache(object):
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
self._model_from_cpu(requested_model)
|
||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
else:
|
||||
self._check_memory()
|
||||
requested_model, width, height = self._load_model(model_name)
|
||||
self.models[model_name] = {}
|
||||
self.models[model_name]['model'] = requested_model
|
||||
self.models[model_name]['width'] = width
|
||||
self.models[model_name]['height'] = height
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {}
|
||||
self.models[model_name]['model'] = requested_model
|
||||
self.models[model_name]['width'] = width
|
||||
self.models[model_name]['height'] = height
|
||||
self.models[model_name]['hash'] = hash
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
return {}
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return requested_model, width, height
|
||||
return {
|
||||
'model':requested_model,
|
||||
'width':width,
|
||||
'height':height,
|
||||
'hash': hash
|
||||
}
|
||||
|
||||
def list_models(self):
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
},
|
||||
model_name2: { etc }
|
||||
'''
|
||||
result = {}
|
||||
for name in self.config:
|
||||
try:
|
||||
description = self.config[name].description
|
||||
@ -70,28 +103,26 @@ class ModelCache(object):
|
||||
status = 'cached'
|
||||
else:
|
||||
status = 'not loaded'
|
||||
print(f'{name:20s} {status:>10s} {description}')
|
||||
result[name]={}
|
||||
result[name]['status']=status
|
||||
result[name]['description']=description
|
||||
return result
|
||||
|
||||
def print_models(self):
|
||||
'''
|
||||
Print a table of models, their descriptions, and load status
|
||||
'''
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
|
||||
|
||||
def _check_memory(self):
|
||||
free_memory = psutil.virtual_memory()[4]
|
||||
print(f'DEBUG: free memory = {free_memory}, min_mem = {self.min_free_mem}')
|
||||
while free_memory + AVG_MODEL_SIZE < self.min_free_mem:
|
||||
|
||||
print(f'DEBUG: free memory = {free_memory}')
|
||||
avail_memory = psutil.virtual_memory()[1]
|
||||
if avail_memory + AVG_MODEL_SIZE < self.min_avail_mem:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
if least_recent_model is None:
|
||||
return
|
||||
|
||||
print(f'DEBUG: clearing {least_recent_model} from cache (refcount = {getrefcount(self.models[least_recent_model]["model"])})')
|
||||
del self.models[least_recent_model]['model']
|
||||
gc.collect()
|
||||
|
||||
new_free_memory = psutil.virtual_memory()[4]
|
||||
if new_free_memory <= free_memory:
|
||||
print(f'>> **Unable to free memory for model caching.**')
|
||||
break;
|
||||
free_memory = new_free_memory
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
@ -106,18 +137,20 @@ class ModelCache(object):
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
print(f'>> Loading {model_name} weights from {weights}')
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||
del weight_bytes
|
||||
sd = pl_sd['state_dict']
|
||||
@ -143,39 +176,40 @@ class ModelCache(object):
|
||||
'\n>> Current VRAM usage:'
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
return model, width, height
|
||||
return model, width, height, model_hash
|
||||
|
||||
def unload_model(self, model_name:str):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
print(f'>> Unloading model {model_name}')
|
||||
print(f'>> Caching model {model_name} in system RAM')
|
||||
model = self.models[model_name]['model']
|
||||
self._model_to_cpu(model)
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self._has_cuda():
|
||||
print(f'DEBUG: moving model to cpu')
|
||||
model.first_stage_model.to('cpu')
|
||||
model.cond_stage_model.to('cpu')
|
||||
model.model.to('cpu')
|
||||
return model.to('cpu')
|
||||
|
||||
def _model_from_cpu(self,model):
|
||||
if self._has_cuda():
|
||||
print(f'DEBUG: moving model into {self.device.type}')
|
||||
model.to(self.device)
|
||||
model.first_stage_model.to(self.device)
|
||||
model.cond_stage_model.to(self.device)
|
||||
return model
|
||||
|
||||
def _pop_oldest_model(self):
|
||||
'''
|
||||
Remove the first element of the FIFO, which ought
|
||||
to be the least recently accessed model.
|
||||
to be the least recently accessed model. Do not
|
||||
pop the last one, because it is in active use!
|
||||
'''
|
||||
if len(self.stack)>0:
|
||||
self.stack.pop(0)
|
||||
if len(self.stack) > 1:
|
||||
return self.stack.pop(0)
|
||||
|
||||
def _push_newest_model(self,model_name:str):
|
||||
'''
|
||||
@ -187,7 +221,6 @@ class ModelCache(object):
|
||||
except ValueError:
|
||||
pass
|
||||
self.stack.append(model_name)
|
||||
print(f'DEBUG, stack={self.stack}')
|
||||
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
@ -47,7 +47,10 @@ COMMANDS = (
|
||||
'--skip_normalize','-x',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'!fix','!fetch','!history','!search','!clear',
|
||||
'!fix','!fetch','!history','!search','!clear','!models','!switch',
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
)
|
||||
IMG_PATH_COMMANDS = (
|
||||
'--outdir[=\s]',
|
||||
@ -63,8 +66,9 @@ IMG_FILE_COMMANDS=(
|
||||
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options):
|
||||
def __init__(self, options, models=[]):
|
||||
self.options = sorted(options)
|
||||
self.models = sorted(models)
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
@ -88,6 +92,9 @@ class Completer(object):
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text,state)
|
||||
|
||||
# This is the first time for this text, so build a match list.
|
||||
elif text:
|
||||
self.matches = [
|
||||
@ -188,6 +195,21 @@ class Completer(object):
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _model_completions(self, text, state):
|
||||
m = re.search('(!switch\s+)(\w*)',text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
@ -266,9 +288,9 @@ class DummyCompleter(Completer):
|
||||
def set_line(self,line):
|
||||
print(f'# {line}')
|
||||
|
||||
def get_completer(opt:Args)->Completer:
|
||||
def get_completer(opt:Args, models=[])->Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS)
|
||||
completer = Completer(COMMANDS,models)
|
||||
|
||||
readline.set_completer(
|
||||
completer.complete
|
||||
|
@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
|
||||
], 'currently only supporting "eps" and "x0"'
|
||||
self.parameterization = parameterization
|
||||
print(
|
||||
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
f' >> {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||
)
|
||||
self.cond_stage_model = None
|
||||
self.clip_denoised = clip_denoised
|
||||
|
@ -245,7 +245,7 @@ class AttnBlock(nn.Module):
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
print(f" >> Making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "none":
|
||||
@ -521,7 +521,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
print(" >> Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
|
@ -75,7 +75,7 @@ def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
f' >> {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
|
@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from omegaconf import OmegaConf
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
|
||||
|
||||
def main():
|
||||
"""Initialize command-line parsers and the diffusion model"""
|
||||
@ -33,7 +31,7 @@ def main():
|
||||
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
||||
sys.exit(-1)
|
||||
|
||||
print('* Initializing, be patient...\n')
|
||||
print('* Initializing, be patient...')
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
@ -42,45 +40,7 @@ def main():
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if opt.restore or opt.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
|
||||
# creating a simple text2image object with a handful of
|
||||
# defaults passed on the command line.
|
||||
# additional parameters will be added (or overriden) during
|
||||
# the user input loop
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
gfpgan,codeformer,esrgan = load_face_restoration(opt)
|
||||
|
||||
# make sure the output directory exists
|
||||
if not os.path.exists(opt.outdir):
|
||||
@ -100,6 +60,24 @@ def main():
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
if opt.seamless:
|
||||
print(">> changed to seamless tiling mode")
|
||||
|
||||
@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
|
||||
done = False
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
last_results = list()
|
||||
model_config = OmegaConf.load(opt.conf)[opt.model]
|
||||
model_config = OmegaConf.load(opt.conf)
|
||||
|
||||
# The readline completer reads history from the .dream_history file located in the
|
||||
# output directory specified at the time of script launch. We do not currently support
|
||||
# changing the history file midstream when the output directory is changed.
|
||||
completer = get_completer(opt)
|
||||
completer = get_completer(opt, models=list(model_config.keys()))
|
||||
output_cntr = completer.get_current_history_length()+1
|
||||
|
||||
# os.pathconf is not available on Windows
|
||||
@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
|
||||
command = command.replace('!fix ','',1)
|
||||
operation = 'postprocess'
|
||||
|
||||
elif subcommand.startswith('switch'):
|
||||
model_name = command.replace('!switch ','',1)
|
||||
gen.set_model(model_name)
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('models'):
|
||||
model_name = command.replace('!models ','',1)
|
||||
gen.model_cache.print_models()
|
||||
continue
|
||||
|
||||
elif subcommand.startswith('fetch'):
|
||||
file_path = command.replace('!fetch ','',1)
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
|
||||
|
||||
# width and height are set by model if not specified
|
||||
if not opt.width:
|
||||
opt.width = model_config.width
|
||||
opt.width = gen.width
|
||||
if not opt.height:
|
||||
opt.height = model_config.height
|
||||
opt.height = gen.height
|
||||
|
||||
# retrieve previous value of init image if requested
|
||||
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
|
||||
@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str: # command string
|
||||
|
||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
|
||||
else:
|
||||
return parts
|
||||
|
||||
def load_face_restoration(opt):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if opt.restore or opt.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
return gfpgan,codeformer,esrgan
|
||||
|
||||
|
||||
def retrieve_dream_command(opt,file_path,completer):
|
||||
'''
|
||||
Given a full or partial path to a previously-generated image file,
|
||||
|
Loading…
Reference in New Issue
Block a user