Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
enable fast switching between models in invoke.py
This PR enables two new commands in the invoke.py script:

    !models          -- list the available models and their cache status
    !switch <model>  -- switch to the indicated model

Example:

    invoke> !models
    laion400m              not loaded  Latent Diffusion LAION400M model
    stable-diffusion-1.4   active      Stable Diffusion inference model version 1.4
    waifu-1.3              cached      Waifu anime model version 1.3
    invoke> !switch waifu-1.3
    >> Caching model stable-diffusion-1.4 in system RAM
    >> Retrieving model waifu-1.3 from system RAM cache

The names and descriptions of the models are taken from `configs/models.yaml`. A future enhancement to `model_cache.py` will be to allow new model stanzas to be added to the file programmatically; this will be useful for the WebGUI.

More details:

- Uses the fast-switching algorithm described in PR #948.
- Models are selected by the name of their configuration stanza in models.yaml.
- To avoid filling CPU RAM with cached models, this PR implements an LRU cache that monitors available CPU RAM.
- The caching code allows the minimum amount of available RAM to be adjusted, but invoke.py does not yet expose a command-line argument for it. The minimum free RAM is arbitrarily set to 2 GB.
- Adds an optional description field to configs/models.yaml.

Unrelated fixes:

- Added ">>" to the CompVis model-loading messages to make the user experience more consistent.
- When generating an image larger than the model's default size, the warning about possible VRAM exhaustion is now printed only once per session.
- Fixed a bug that caused the help message to be printed twice; the fix moves the web-backend import into the function where it is used.

Co-authored-by: @ArDiouscuros
This commit is contained in:
parent b9e910b5f4
commit 488334710b

Diff stat (as preserved by the mirror): ldm/generate.py — 152 changed lines
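The model names and descriptions that `!models` prints come straight from `configs/models.yaml`, which the code reads with OmegaConf. As an illustration of the stanza format this PR relies on (including the new optional `description` field), here is a minimal sketch; the weights/config paths and values are placeholders, not taken from the repository:

```python
from omegaconf import OmegaConf

# Hypothetical models.yaml stanza; the field names (config, weights, width,
# height, description) are the ones the code reads, the values are made up.
MODELS_YAML = """
stable-diffusion-1.4:
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/model.ckpt
  width: 512
  height: 512
  description: Stable Diffusion inference model version 1.4
"""

models = OmegaConf.create(MODELS_YAML)   # the real file is read with OmegaConf.load()
for name in models:
    print(f'{name:20s} {models[name].description}')
```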
ldm/generate.py:

@@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.model_cache import ModelCache
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -123,12 +124,11 @@ class Generate:
             esrgan=None,
             free_gpu_mem=False,
     ):
-        models = OmegaConf.load(conf)
-        mconfig = models[model]
-        self.weights = mconfig.weights if weights is None else weights
-        self.config = mconfig.config if config is None else config
-        self.height = mconfig.height
-        self.width = mconfig.width
+        mconfig = OmegaConf.load(conf)
+        self.model_name = model
+        self.height = None
+        self.width = None
+        self.model_cache = None
         self.iterations = 1
         self.steps = 50
         self.cfg_scale = 7.5
@@ -139,6 +139,7 @@ class Generate:
         self.seamless = False
         self.embedding_path = embedding_path
         self.model = None     # empty for now
+        self.model_hash = None
         self.sampler = None
         self.device = None
         self.session_peakmem = None
@@ -149,6 +150,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.size_matters = True  # used to warn once about large image sizes and VRAM
 
         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -164,6 +166,9 @@ class Generate:
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
 
+        # model caching system for fast switching
+        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
         transformers.logging.set_verbosity_error()
@@ -294,7 +299,12 @@ class Generate:
         with_variations = [] if with_variations is None else with_variations
 
         # will instantiate the model or return it from cache
-        model = self.load_model()
+        model = self.set_model(self.model_name)
 
+        # self.width and self.height are set by set_model()
+        # to the width and height of the image training set
+        width  = width  or self.width
+        height = height or self.height
+
         for m in model.modules():
             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
@@ -584,8 +594,9 @@ class Generate:
             # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)
 
-        if (image.width * image.height) > (self.width * self.height):
+        if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
+            self.size_matters = False
 
         init_image = self._create_init_image(image,width,height,fit=fit)   # this returns a torch tensor
 
@@ -635,29 +646,47 @@ class Generate:
         return self.generators['inpaint']
 
     def load_model(self):
-        """Load and initialize the model from configuration variables passed at object creation time"""
-        if self.model is None:
-            seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
-            try:
-                model = self._load_model_from_config(self.config, self.weights)
-                if self.embedding_path is not None:
-                    model.embedding_manager.load(
-                        self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
-                    )
-                self.model = model.to(self.device)
-                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
-                self.model.cond_stage_model.device = self.device
-            except AttributeError as e:
-                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
-                raise SystemExit from e
-
-            self._set_sampler()
-
-            for m in self.model.modules():
-                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                    m._orig_padding_mode = m.padding_mode
-
+        '''
+        preload model identified in self.model_name
+        '''
+        self.set_model(self.model_name)
+
+    def set_model(self,model_name):
+        """
+        Given the name of a model defined in models.yaml, will load and initialize it
+        and return the model object. Previously-used models will be cached.
+        """
+        if self.model_name == model_name and self.model is not None:
+            return self.model
+
+        model_data = self.model_cache.get_model(model_name)
+        if model_data is None or len(model_data) == 0:
+            print(f'** Model switch failed **')
+            return self.model
+
+        self.model      = model_data['model']
+        self.width      = model_data['width']
+        self.height     = model_data['height']
+        self.model_hash = model_data['hash']
+
+        # uncache generators so they pick up new models
+        self.generators = {}
+
+        seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
+        if self.embedding_path is not None:
+            model.embedding_manager.load(
+                self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
+            )
+
+        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
+        self.model.cond_stage_model.device = self.device
+        self._set_sampler()
+
+        for m in self.model.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                m._orig_padding_mode = m.padding_mode
+
+        self.model_name = model_name
         return self.model
 
     def correct_colors(self,
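For illustration, with `set_model()` in place the same fast switch that `!switch` performs can be driven from Python. This is a sketch under the assumption that the package is importable as in the diff; the model names and config path echo the example session in the commit message:

```python
from ldm.generate import Generate

gen = Generate(model='stable-diffusion-1.4', conf='configs/models.yaml')
gen.load_model()                    # preload the model named at construction time

# Fast switch: a previously used model comes back from the RAM cache,
# anything else is loaded from disk and the outgoing model is cached.
gen.set_model('waifu-1.3')
print(gen.model_name, gen.width, gen.height, gen.model_hash)
```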
@@ -761,53 +790,6 @@ class Generate:
 
         print(msg)
 
-    # Be warned: config is the path to the model config file, not the invoke conf file!
-    # Also note that we can get config and weights from self, so why do we need to
-    # pass them as args?
-    def _load_model_from_config(self, config, weights):
-        print(f'>> Loading model from {weights}')
-
-        # for usage statistics
-        device_type = choose_torch_device()
-        if device_type == 'cuda':
-            torch.cuda.reset_peak_memory_stats()
-        tic = time.time()
-
-        # this does the work
-        c = OmegaConf.load(config)
-        with open(weights,'rb') as f:
-            weight_bytes = f.read()
-        self.model_hash = self._cached_sha256(weights,weight_bytes)
-        pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
-        del weight_bytes
-        sd = pl_sd['state_dict']
-        model = instantiate_from_config(c.model)
-        m, u = model.load_state_dict(sd, strict=False)
-
-        if self.precision == 'float16':
-            print('>> Using faster float16 precision')
-            model.to(torch.float16)
-        else:
-            print('>> Using more accurate float32 precision')
-
-        model.to(self.device)
-        model.eval()
-
-        # usage statistics
-        toc = time.time()
-        print(
-            f'>> Model loaded in', '%4.2fs' % (toc - tic)
-        )
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-
-        return model
-
     def _load_img(self, img, width, height)->Image:
         if isinstance(img, Image.Image):
             image = img
@@ -951,26 +933,6 @@ class Generate:
     def _has_cuda(self):
        return self.device.type == 'cuda'
 
-    def _cached_sha256(self,path,data):
-        dirname  = os.path.dirname(path)
-        basename = os.path.basename(path)
-        base, _  = os.path.splitext(basename)
-        hashpath = os.path.join(dirname,base+'.sha256')
-        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
-            with open(hashpath) as f:
-                hash = f.read()
-            return hash
-        print(f'>> Calculating sha256 hash of weights file')
-        tic = time.time()
-        sha = hashlib.sha256()
-        sha.update(data)
-        hash = sha.hexdigest()
-        toc = time.time()
-        print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
-        with open(hashpath,'w') as f:
-            f.write(hash)
-        return hash
-
     def write_intermediate_images(self,modulus,path):
         counter = -1
         if not os.path.exists(path):
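`_load_model_from_config()` and `_cached_sha256()` are removed here because the model cache takes over both jobs (see the ModelCache hunks below). The sidecar-file trick that `_cached_sha256()` uses — compute the sha256 of the weights once and memoize it next to the checkpoint — looks like this in isolation (an illustrative sketch, not the repository's exact code):

```python
import hashlib
import os
import time

def cached_sha256(path: str, data: bytes) -> str:
    """Return the sha256 of `data`, memoized in a .sha256 file next to `path`."""
    base, _ = os.path.splitext(os.path.basename(path))
    hashpath = os.path.join(os.path.dirname(path), base + '.sha256')
    # Reuse the stored digest if the weights file has not changed since it was written.
    if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
        with open(hashpath) as f:
            return f.read()
    tic = time.time()
    digest = hashlib.sha256(data).hexdigest()
    print(f'>> sha256 = {digest} ({time.time() - tic:.2f}s)')
    with open(hashpath, 'w') as f:
        f.write(digest)
    return digest
```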
ldm/invoke/model_cache.py (ModelCache):

@@ -20,22 +20,35 @@ from ldm.util import instantiate_from_config
 
 GIGS=2**30
 AVG_MODEL_SIZE=2.1*GIGS
+DEFAULT_MIN_AVAIL=2*GIGS
 
 class ModelCache(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, min_free_mem=2*GIGS):
+    def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
+        '''
+        Initialize with the path to the models.yaml config file,
+        the torch device type, and precision. The optional
+        min_avail_mem argument specifies how much unused system
+        (CPU) memory to preserve. The cache of models in RAM will
+        grow until this value is approached. Default is 2G.
+        '''
         # prevent nasty-looking CLIP log message
         transformers.logging.set_verbosity_error()
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)
-        self.min_free_mem = min_free_mem
+        self.min_avail_mem = min_avail_mem
         self.models = {}
         self.stack = []  # this is an LRU FIFO
         self.current_model = None
 
     def get_model(self, model_name:str):
+        '''
+        Given a model named identified in models.yaml, return
+        the model object. If in RAM will load into GPU VRAM.
+        If on disk, will load from there.
+        '''
         if model_name not in self.config:
-            print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
+            print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
             return None
 
         if self.current_model != model_name:
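For illustration, constructing the cache directly with the signature above might look like this; the config path, device selection, and the 4 GB floor are placeholder choices, not values from the PR:

```python
import torch
from omegaconf import OmegaConf
from ldm.invoke.model_cache import ModelCache

config = OmegaConf.load('configs/models.yaml')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keep roughly 4 GB of system RAM free instead of the 2 GB default.
cache = ModelCache(config, device, precision='float16', min_avail_mem=4 * 2**30)
```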
@@ -43,22 +56,42 @@ class ModelCache(object):
 
         if model_name in self.models:
             requested_model = self.models[model_name]['model']
-            self._model_from_cpu(requested_model)
+            print(f'>> Retrieving model {model_name} from system RAM cache')
+            self.models[model_name]['model'] = self._model_from_cpu(requested_model)
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
+            hash = self.models[model_name]['hash']
         else:
             self._check_memory()
-            requested_model, width, height = self._load_model(model_name)
-            self.models[model_name] = {}
-            self.models[model_name]['model'] = requested_model
-            self.models[model_name]['width'] = width
-            self.models[model_name]['height'] = height
+            try:
+                requested_model, width, height, hash = self._load_model(model_name)
+                self.models[model_name] = {}
+                self.models[model_name]['model'] = requested_model
+                self.models[model_name]['width'] = width
+                self.models[model_name]['height'] = height
+                self.models[model_name]['hash'] = hash
+            except Exception as e:
+                print(f'** model {model_name} could not be loaded: {str(e)}')
+                return {}
 
         self.current_model = model_name
         self._push_newest_model(model_name)
-        return requested_model, width, height
+        return {
+            'model':requested_model,
+            'width':width,
+            'height':height,
+            'hash': hash
+        }
 
-    def list_models(self):
+    def list_models(self) -> dict:
+        '''
+        Return a dict of models in the format:
+        { model_name1: {'status': ('active'|'cached'|'not loaded'),
+                        'description': description,
+                       },
+          model_name2: { etc }
+        '''
+        result = {}
         for name in self.config:
             try:
                 description = self.config[name].description
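`get_model()` now hands back a dict instead of a tuple, and an empty dict signals a failed load. A small sketch of how a caller is expected to unpack it, mirroring `Generate.set_model()` earlier in this diff (the helper name here is hypothetical):

```python
def activate(cache, model_name: str):
    """Fetch a model from a ModelCache and unpack the dict that get_model() returns."""
    model_data = cache.get_model(model_name)
    if not model_data:                     # None (unknown name) or {} (load failed)
        print('** Model switch failed **')
        return None
    return (model_data['model'],           # the torch module, on the GPU when available
            model_data['width'],           # native resolution recorded in models.yaml
            model_data['height'],
            model_data['hash'])            # memoized sha256 of the weights file
```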
@@ -70,28 +103,26 @@ class ModelCache(object):
                     status = 'cached'
                 else:
                     status = 'not loaded'
-            print(f'{name:20s} {status:>10s} {description}')
+            result[name]={}
+            result[name]['status']=status
+            result[name]['description']=description
+        return result
+
+    def print_models(self):
+        '''
+        Print a table of models, their descriptions, and load status
+        '''
+        models = self.list_models()
+        for name in models:
+            print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
 
     def _check_memory(self):
-        free_memory = psutil.virtual_memory()[4]
-        print(f'DEBUG: free memory = {free_memory}, min_mem = {self.min_free_mem}')
-        while free_memory + AVG_MODEL_SIZE < self.min_free_mem:
-            print(f'DEBUG: free memory = {free_memory}')
+        avail_memory = psutil.virtual_memory()[1]
+        if avail_memory + AVG_MODEL_SIZE < self.min_avail_mem:
             least_recent_model = self._pop_oldest_model()
-            if least_recent_model is None:
-                return
-            print(f'DEBUG: clearing {least_recent_model} from cache (refcount = {getrefcount(self.models[least_recent_model]["model"])})')
-            del self.models[least_recent_model]['model']
-            gc.collect()
-
-            new_free_memory = psutil.virtual_memory()[4]
-            if new_free_memory <= free_memory:
-                print(f'>> **Unable to free memory for model caching.**')
-                break;
-            free_memory = new_free_memory
-
+            if least_recent_model is not None:
+                del self.models[least_recent_model]
+                gc.collect()
 
     def _load_model(self, model_name:str):
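The cache sizes itself against available system RAM rather than a fixed number of entries: psutil.virtual_memory() supplies the "available" figure (the [1] field read by `_check_memory()`), and the least recently used model is dropped when another roughly 2.1 GB checkpoint would crowd the configured floor. A self-contained sketch of that idea; the loop and the exact threshold arithmetic are a simplification, not a copy of `_check_memory()`:

```python
import gc
import psutil

GIGS = 2**30
AVG_MODEL_SIZE = 2.1 * GIGS          # rough RAM footprint of one cached checkpoint

def make_room(models: dict, lru_stack: list, min_avail_mem: float = 2 * GIGS) -> None:
    """Evict least-recently-used cached models while loading one more ~2.1 GB
    checkpoint would leave less than min_avail_mem bytes of system RAM available."""
    while psutil.virtual_memory().available < min_avail_mem + AVG_MODEL_SIZE:
        if len(lru_stack) <= 1:          # never evict the model in active use
            break
        victim = lru_stack.pop(0)        # oldest entry of the LRU FIFO
        models.pop(victim, None)         # drop the reference to the cached weights
        gc.collect()                     # encourage Python to return the RAM
```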
@@ -106,18 +137,20 @@ class ModelCache(object):
         width = mconfig.width
         height = mconfig.height
 
-        print(f'>> Loading {model_name} weights from {weights}')
+        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
+            torch.cuda.empty_cache()
+
         tic = time.time()
 
         # this does the work
         c = OmegaConf.load(config)
         with open(weights,'rb') as f:
             weight_bytes = f.read()
-        self.model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights,weight_bytes)
         pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
         sd = pl_sd['state_dict']
@@ -143,39 +176,40 @@ class ModelCache(object):
                 '\n>> Current VRAM usage:'
                 '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
             )
-        return model, width, height
+        return model, width, height, model_hash
 
     def unload_model(self, model_name:str):
         if model_name not in self.models:
             return
-        print(f'>> Unloading model {model_name}')
+        print(f'>> Caching model {model_name} in system RAM')
         model = self.models[model_name]['model']
-        self._model_to_cpu(model)
+        self.models[model_name]['model'] = self._model_to_cpu(model)
         gc.collect()
         if self._has_cuda():
             torch.cuda.empty_cache()
 
     def _model_to_cpu(self,model):
         if self._has_cuda():
-            print(f'DEBUG: moving model to cpu')
             model.first_stage_model.to('cpu')
             model.cond_stage_model.to('cpu')
             model.model.to('cpu')
+        return model.to('cpu')
 
     def _model_from_cpu(self,model):
         if self._has_cuda():
-            print(f'DEBUG: moving model into {self.device.type}')
             model.to(self.device)
             model.first_stage_model.to(self.device)
             model.cond_stage_model.to(self.device)
+        return model
 
     def _pop_oldest_model(self):
         '''
         Remove the first element of the FIFO, which ought
-        to be the least recently accessed model.
+        to be the least recently accessed model. Do not
+        pop the last one, because it is in active use!
         '''
-        if len(self.stack)>0:
-            self.stack.pop(0)
+        if len(self.stack) > 1:
+            return self.stack.pop(0)
 
     def _push_newest_model(self,model_name:str):
         '''
@@ -187,7 +221,6 @@ class ModelCache(object):
         except ValueError:
             pass
         self.stack.append(model_name)
-        print(f'DEBUG, stack={self.stack}')
 
     def _has_cuda(self):
         return self.device.type == 'cuda'
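The eviction order is kept in `self.stack`, a plain list used as an LRU FIFO: `_push_newest_model()` moves a name to the end, and `_pop_oldest_model()` takes from the front but never touches the last (active) entry. The same bookkeeping in isolation, as a sketch with placeholder model names:

```python
def push_newest(stack: list, name: str) -> None:
    """Move `name` to the end of the LRU list, appending it if it is new."""
    if name in stack:
        stack.remove(name)
    stack.append(name)

def pop_oldest(stack: list):
    """Return the least recently used entry, never the last (active) one."""
    if len(stack) > 1:
        return stack.pop(0)
    return None

stack = []
for used in ['stable-diffusion-1.4', 'waifu-1.3', 'stable-diffusion-1.4']:
    push_newest(stack, used)
print(stack)               # ['waifu-1.3', 'stable-diffusion-1.4']
print(pop_oldest(stack))   # 'waifu-1.3' is the eviction candidate
```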
Readline command completer (Completer / get_completer):

@@ -47,7 +47,10 @@ COMMANDS = (
     '--skip_normalize','-x',
     '--log_tokenization','-t',
     '--hires_fix',
-    '!fix','!fetch','!history','!search','!clear',
+    '!fix','!fetch','!history','!search','!clear','!models','!switch',
+    )
+MODEL_COMMANDS = (
+    '!switch',
     )
 IMG_PATH_COMMANDS = (
     '--outdir[=\s]',
@@ -63,8 +66,9 @@ IMG_FILE_COMMANDS=(
 path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
 
 class Completer(object):
-    def __init__(self, options):
+    def __init__(self, options, models=[]):
         self.options = sorted(options)
+        self.models = sorted(models)
         self.seeds = set()
         self.matches = list()
         self.default_dir = None
@@ -88,6 +92,9 @@ class Completer(object):
         elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
             self.matches= self._seed_completions(text,state)
 
+        elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
+            self.matches= self._model_completions(text,state)
+
         # This is the first time for this text, so build a match list.
         elif text:
             self.matches = [
@@ -188,6 +195,21 @@ class Completer(object):
         matches.sort()
         return matches
 
+    def _model_completions(self, text, state):
+        m = re.search('(!switch\s+)(\w*)',text)
+        if m:
+            switch  = m.groups()[0]
+            partial = m.groups()[1]
+        else:
+            switch  = ''
+            partial = text
+        matches = list()
+        for s in self.models:
+            if s.startswith(partial):
+                matches.append(switch+s)
+        matches.sort()
+        return matches
+
     def _pre_input_hook(self):
         if self.linebuffer:
             readline.insert_text(self.linebuffer)
@@ -266,9 +288,9 @@ class DummyCompleter(Completer):
     def set_line(self,line):
         print(f'# {line}')
 
-def get_completer(opt:Args)->Completer:
+def get_completer(opt:Args, models=[])->Completer:
     if readline_available:
-        completer = Completer(COMMANDS)
+        completer = Completer(COMMANDS,models)
 
         readline.set_completer(
             completer.complete
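The completer wires `!switch` into GNU readline tab completion by matching the line buffer against MODEL_COMMANDS and offering model names. The same mechanism in miniature, using only the standard readline module rather than the project's Completer class (the model names are placeholders; normally they would be the models.yaml keys):

```python
import readline

MODELS = ['laion400m', 'stable-diffusion-1.4', 'waifu-1.3']

def complete(text, state):
    """Return the state-th model name matching `text` when the line starts with !switch."""
    if not readline.get_line_buffer().startswith('!switch'):
        return None
    matches = [name for name in sorted(MODELS) if name.startswith(text)]
    return matches[state] if state < len(matches) else None

readline.set_completer(complete)
readline.parse_and_bind('tab: complete')
```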
Consistent ">>" prefixes on the CompVis model-loading messages:

@@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
         ], 'currently only supporting "eps" and "x0"'
         self.parameterization = parameterization
         print(
-            f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
+            f' >> {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
         )
         self.cond_stage_model = None
         self.clip_denoised = clip_denoised
@@ -245,7 +245,7 @@ class AttnBlock(nn.Module):
 
 def make_attn(in_channels, attn_type="vanilla"):
     assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
-    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    print(f" >> Making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         return AttnBlock(in_channels)
     elif attn_type == "none":
@@ -521,7 +521,7 @@ class Decoder(nn.Module):
         block_in = ch*ch_mult[self.num_resolutions-1]
         curr_res = resolution // 2**(self.num_resolutions-1)
         self.z_shape = (1,z_channels,curr_res,curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(
+        print(" >> Working with z of shape {} = {} dimensions.".format(
             self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
@@ -75,7 +75,7 @@ def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
         print(
-            f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
+            f' >> {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
         )
     return total_params
 
invoke.py (CLI script):

@@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
 from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from omegaconf import OmegaConf
-from backend.invoke_ai_web_server import InvokeAIWebServer
-
 
 def main():
     """Initialize command-line parsers and the diffusion model"""
@@ -33,7 +31,7 @@ def main():
         print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
         sys.exit(-1)
 
-    print('* Initializing, be patient...\n')
+    print('* Initializing, be patient...')
     from ldm.generate import Generate
 
     # these two lines prevent a horrible warning message from appearing
@@ -42,45 +40,7 @@ def main():
     transformers.logging.set_verbosity_error()
 
     # Loading Face Restoration and ESRGAN Modules
-    try:
-        gfpgan, codeformer, esrgan = None, None, None
-        if opt.restore or opt.esrgan:
-            from ldm.invoke.restoration import Restoration
-            restoration = Restoration()
-            if opt.restore:
-                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
-            else:
-                print('>> Face restoration disabled')
-            if opt.esrgan:
-                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
-            else:
-                print('>> Upscaling disabled')
-        else:
-            print('>> Face restoration and upscaling disabled')
-    except (ModuleNotFoundError, ImportError):
-        print(traceback.format_exc(), file=sys.stderr)
-        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
-
-    # creating a simple text2image object with a handful of
-    # defaults passed on the command line.
-    # additional parameters will be added (or overriden) during
-    # the user input loop
-    try:
-        gen = Generate(
-            conf           = opt.conf,
-            model          = opt.model,
-            sampler_name   = opt.sampler_name,
-            embedding_path = opt.embedding_path,
-            full_precision = opt.full_precision,
-            precision      = opt.precision,
-            gfpgan=gfpgan,
-            codeformer=codeformer,
-            esrgan=esrgan,
-            free_gpu_mem=opt.free_gpu_mem,
-        )
-    except (FileNotFoundError, IOError, KeyError) as e:
-        print(f'{e}. Aborting.')
-        sys.exit(-1)
+    gfpgan,codeformer,esrgan = load_face_restoration(opt)
 
     # make sure the output directory exists
     if not os.path.exists(opt.outdir):
@@ -100,6 +60,24 @@ def main():
         print(f'{e}. Aborting.')
         sys.exit(-1)
 
+    # creating a Generate object:
+    try:
+        gen = Generate(
+            conf           = opt.conf,
+            model          = opt.model,
+            sampler_name   = opt.sampler_name,
+            embedding_path = opt.embedding_path,
+            full_precision = opt.full_precision,
+            precision      = opt.precision,
+            gfpgan=gfpgan,
+            codeformer=codeformer,
+            esrgan=esrgan,
+            free_gpu_mem=opt.free_gpu_mem,
+        )
+    except (FileNotFoundError, IOError, KeyError) as e:
+        print(f'{e}. Aborting.')
+        sys.exit(-1)
+
     if opt.seamless:
         print(">> changed to seamless tiling mode")
 
@@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
     done = False
     path_filter = re.compile(r'[<>:"/\\|?*]')
     last_results = list()
-    model_config = OmegaConf.load(opt.conf)[opt.model]
+    model_config = OmegaConf.load(opt.conf)
 
     # The readline completer reads history from the .dream_history file located in the
     # output directory specified at the time of script launch. We do not currently support
     # changing the history file midstream when the output directory is changed.
-    completer = get_completer(opt)
+    completer = get_completer(opt, models=list(model_config.keys()))
     output_cntr = completer.get_current_history_length()+1
 
     # os.pathconf is not available on Windows
@@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
                 command = command.replace('!fix ','',1)
                 operation = 'postprocess'
 
+            elif subcommand.startswith('switch'):
+                model_name = command.replace('!switch ','',1)
+                gen.set_model(model_name)
+                continue
+
+            elif subcommand.startswith('models'):
+                model_name = command.replace('!models ','',1)
+                gen.model_cache.print_models()
+                continue
+
             elif subcommand.startswith('fetch'):
                 file_path = command.replace('!fetch ','',1)
                 retrieve_dream_command(opt,file_path,completer)
@@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
 
         # width and height are set by model if not specified
         if not opt.width:
-            opt.width = model_config.width
+            opt.width = gen.width
         if not opt.height:
-            opt.height = model_config.height
+            opt.height = gen.height
 
         # retrieve previous value of init image if requested
         if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
@@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str:  # command string
 
 def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
     print('\n* --web was specified, starting web server...')
+    from backend.invoke_ai_web_server import InvokeAIWebServer
     # Change working directory to the stable-diffusion directory
     os.chdir(
         os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
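Importing the web backend inside `invoke_ai_web_server_loop()` is what fixes the doubled help message noted in the commit message: the module, with its import-time side effects, is only loaded when --web is actually used. The general deferred-import pattern, sketched with a hypothetical optional module name:

```python
import importlib

def load_optional_backend(module_name: str = 'some_heavy_backend'):
    """Import an optional dependency only when its feature is actually requested."""
    try:
        # Deferred import: nothing from the module runs at program start-up,
        # so its import-time argument parsing or logging cannot interfere.
        return importlib.import_module(module_name)   # hypothetical module name
    except ImportError as err:
        print(f'>> optional backend unavailable: {err}')
        return None
```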
@@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
     else:
         return parts
 
+def load_face_restoration(opt):
+    try:
+        gfpgan, codeformer, esrgan = None, None, None
+        if opt.restore or opt.esrgan:
+            from ldm.invoke.restoration import Restoration
+            restoration = Restoration()
+            if opt.restore:
+                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
+            else:
+                print('>> Face restoration disabled')
+            if opt.esrgan:
+                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
+            else:
+                print('>> Upscaling disabled')
+        else:
+            print('>> Face restoration and upscaling disabled')
+    except (ModuleNotFoundError, ImportError):
+        print(traceback.format_exc(), file=sys.stderr)
+        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
+    return gfpgan,codeformer,esrgan
+
+
 def retrieve_dream_command(opt,file_path,completer):
     '''
     Given a full or partial path to a previously-generated image file,
|
Loading…
x
Reference in New Issue
Block a user