mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enable fast switching between models in invoke.py
- This PR enables two new commands in the invoke.py script !models -- list the available models and their cache status !switch <model> -- switch to the indicated model Example: invoke> !models laion400m not loaded Latent Diffusion LAION400M model stable-diffusion-1.4 active Stable Diffusion inference model version 1.4 waifu-1.3 cached Waifu anime model version 1.3 invoke> !switch waifu-1.3 >> Caching model stable-diffusion-1.4 in system RAM >> Retrieving model waifu-1.3 from system RAM cache More details: - Use fast switching algorithm described in PR #948 - Models are selected using their configuration stanza name given in models.yaml. - To avoid filling up CPU RAM with cached models, this PR implements an LRU cache that monitors available CPU RAM. - The caching code allows the minimum value of available RAM to be adjusted, but invoke.py does not currently have a command-line argument that allows you to set it. The minimum free RAM is arbitrarily set to 2 GB. - Add optional description field to configs/models.yaml Unrelated fixes: - Added ">>" to CompViz model loading messages in order to make user experience more consistent. - When generating an image greater than defaults, will only warn about possible VRAM filling the first time. - Fixed bug that was causing help message to be printed twice. This involved moving the import line for the web backend into the section where it is called.
This commit is contained in:
parent
b9e910b5f4
commit
19341e95a6
152
ldm/generate.py
152
ldm/generate.py
@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
|
|||||||
from ldm.invoke.image_util import InitImageResizer
|
from ldm.invoke.image_util import InitImageResizer
|
||||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||||
from ldm.invoke.conditioning import get_uc_and_c
|
from ldm.invoke.conditioning import get_uc_and_c
|
||||||
|
from ldm.invoke.model_cache import ModelCache
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
@ -123,12 +124,11 @@ class Generate:
|
|||||||
esrgan=None,
|
esrgan=None,
|
||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
):
|
):
|
||||||
models = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
mconfig = models[model]
|
self.model_name = model
|
||||||
self.weights = mconfig.weights if weights is None else weights
|
self.height = None
|
||||||
self.config = mconfig.config if config is None else config
|
self.width = None
|
||||||
self.height = mconfig.height
|
self.model_cache = None
|
||||||
self.width = mconfig.width
|
|
||||||
self.iterations = 1
|
self.iterations = 1
|
||||||
self.steps = 50
|
self.steps = 50
|
||||||
self.cfg_scale = 7.5
|
self.cfg_scale = 7.5
|
||||||
@ -139,6 +139,7 @@ class Generate:
|
|||||||
self.seamless = False
|
self.seamless = False
|
||||||
self.embedding_path = embedding_path
|
self.embedding_path = embedding_path
|
||||||
self.model = None # empty for now
|
self.model = None # empty for now
|
||||||
|
self.model_hash = None
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
self.device = None
|
self.device = None
|
||||||
self.session_peakmem = None
|
self.session_peakmem = None
|
||||||
@ -149,6 +150,7 @@ class Generate:
|
|||||||
self.codeformer = codeformer
|
self.codeformer = codeformer
|
||||||
self.esrgan = esrgan
|
self.esrgan = esrgan
|
||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
|
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -164,6 +166,9 @@ class Generate:
|
|||||||
if self.precision == 'auto':
|
if self.precision == 'auto':
|
||||||
self.precision = choose_precision(self.device)
|
self.precision = choose_precision(self.device)
|
||||||
|
|
||||||
|
# model caching system for fast switching
|
||||||
|
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -294,7 +299,12 @@ class Generate:
|
|||||||
with_variations = [] if with_variations is None else with_variations
|
with_variations = [] if with_variations is None else with_variations
|
||||||
|
|
||||||
# will instantiate the model or return it from cache
|
# will instantiate the model or return it from cache
|
||||||
model = self.load_model()
|
model = self.set_model(self.model_name)
|
||||||
|
|
||||||
|
# self.width and self.height are set by set_model()
|
||||||
|
# to the width and height of the image training set
|
||||||
|
width = width or self.width
|
||||||
|
height = height or self.height
|
||||||
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
@ -584,8 +594,9 @@ class Generate:
|
|||||||
# this returns a torch tensor
|
# this returns a torch tensor
|
||||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||||
|
|
||||||
if (image.width * image.height) > (self.width * self.height):
|
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||||
|
self.size_matters = False
|
||||||
|
|
||||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||||
|
|
||||||
@ -635,29 +646,47 @@ class Generate:
|
|||||||
return self.generators['inpaint']
|
return self.generators['inpaint']
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
'''
|
||||||
if self.model is None:
|
preload model identified in self.model_name
|
||||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
'''
|
||||||
try:
|
self.set_model(self.model_name)
|
||||||
model = self._load_model_from_config(self.config, self.weights)
|
|
||||||
if self.embedding_path is not None:
|
|
||||||
model.embedding_manager.load(
|
|
||||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
|
||||||
)
|
|
||||||
self.model = model.to(self.device)
|
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
|
||||||
self.model.cond_stage_model.device = self.device
|
|
||||||
except AttributeError as e:
|
|
||||||
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
self._set_sampler()
|
def set_model(self,model_name):
|
||||||
|
"""
|
||||||
|
Given the name of a model defined in models.yaml, will load and initialize it
|
||||||
|
and return the model object. Previously-used models will be cached.
|
||||||
|
"""
|
||||||
|
if self.model_name == model_name and self.model is not None:
|
||||||
|
return self.model
|
||||||
|
|
||||||
for m in self.model.modules():
|
model_data = self.model_cache.get_model(model_name)
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
if model_data is None or len(model_data) == 0:
|
||||||
m._orig_padding_mode = m.padding_mode
|
print(f'** Model switch failed **')
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
self.model = model_data['model']
|
||||||
|
self.width = model_data['width']
|
||||||
|
self.height= model_data['height']
|
||||||
|
self.model_hash = model_data['hash']
|
||||||
|
|
||||||
|
# uncache generators so they pick up new models
|
||||||
|
self.generators = {}
|
||||||
|
|
||||||
|
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||||
|
if self.embedding_path is not None:
|
||||||
|
model.embedding_manager.load(
|
||||||
|
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||||
|
)
|
||||||
|
|
||||||
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
|
self.model.cond_stage_model.device = self.device
|
||||||
|
self._set_sampler()
|
||||||
|
|
||||||
|
for m in self.model.modules():
|
||||||
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
|
m._orig_padding_mode = m.padding_mode
|
||||||
|
|
||||||
|
self.model_name = model_name
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def correct_colors(self,
|
def correct_colors(self,
|
||||||
@ -761,53 +790,6 @@ class Generate:
|
|||||||
|
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
# Be warned: config is the path to the model config file, not the invoke conf file!
|
|
||||||
# Also note that we can get config and weights from self, so why do we need to
|
|
||||||
# pass them as args?
|
|
||||||
def _load_model_from_config(self, config, weights):
|
|
||||||
print(f'>> Loading model from {weights}')
|
|
||||||
|
|
||||||
# for usage statistics
|
|
||||||
device_type = choose_torch_device()
|
|
||||||
if device_type == 'cuda':
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
tic = time.time()
|
|
||||||
|
|
||||||
# this does the work
|
|
||||||
c = OmegaConf.load(config)
|
|
||||||
with open(weights,'rb') as f:
|
|
||||||
weight_bytes = f.read()
|
|
||||||
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
|
||||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
|
||||||
del weight_bytes
|
|
||||||
sd = pl_sd['state_dict']
|
|
||||||
model = instantiate_from_config(c.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
if self.precision == 'float16':
|
|
||||||
print('>> Using faster float16 precision')
|
|
||||||
model.to(torch.float16)
|
|
||||||
else:
|
|
||||||
print('>> Using more accurate float32 precision')
|
|
||||||
|
|
||||||
model.to(self.device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# usage statistics
|
|
||||||
toc = time.time()
|
|
||||||
print(
|
|
||||||
f'>> Model loaded in', '%4.2fs' % (toc - tic)
|
|
||||||
)
|
|
||||||
if self._has_cuda():
|
|
||||||
print(
|
|
||||||
'>> Max VRAM used to load the model:',
|
|
||||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
|
||||||
'\n>> Current VRAM usage:'
|
|
||||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _load_img(self, img, width, height)->Image:
|
def _load_img(self, img, width, height)->Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
@ -951,26 +933,6 @@ class Generate:
|
|||||||
def _has_cuda(self):
|
def _has_cuda(self):
|
||||||
return self.device.type == 'cuda'
|
return self.device.type == 'cuda'
|
||||||
|
|
||||||
def _cached_sha256(self,path,data):
|
|
||||||
dirname = os.path.dirname(path)
|
|
||||||
basename = os.path.basename(path)
|
|
||||||
base, _ = os.path.splitext(basename)
|
|
||||||
hashpath = os.path.join(dirname,base+'.sha256')
|
|
||||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
|
||||||
with open(hashpath) as f:
|
|
||||||
hash = f.read()
|
|
||||||
return hash
|
|
||||||
print(f'>> Calculating sha256 hash of weights file')
|
|
||||||
tic = time.time()
|
|
||||||
sha = hashlib.sha256()
|
|
||||||
sha.update(data)
|
|
||||||
hash = sha.hexdigest()
|
|
||||||
toc = time.time()
|
|
||||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
|
||||||
with open(hashpath,'w') as f:
|
|
||||||
f.write(hash)
|
|
||||||
return hash
|
|
||||||
|
|
||||||
def write_intermediate_images(self,modulus,path):
|
def write_intermediate_images(self,modulus,path):
|
||||||
counter = -1
|
counter = -1
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
|
@ -20,22 +20,35 @@ from ldm.util import instantiate_from_config
|
|||||||
|
|
||||||
GIGS=2**30
|
GIGS=2**30
|
||||||
AVG_MODEL_SIZE=2.1*GIGS
|
AVG_MODEL_SIZE=2.1*GIGS
|
||||||
|
DEFAULT_MIN_AVAIL=2*GIGS
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_free_mem=2*GIGS):
|
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
|
||||||
|
'''
|
||||||
|
Initialize with the path to the models.yaml config file,
|
||||||
|
the torch device type, and precision. The optional
|
||||||
|
min_avail_mem argument specifies how much unused system
|
||||||
|
(CPU) memory to preserve. The cache of models in RAM will
|
||||||
|
grow until this value is approached. Default is 2G.
|
||||||
|
'''
|
||||||
# prevent nasty-looking CLIP log message
|
# prevent nasty-looking CLIP log message
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.device = torch.device(device_type)
|
self.device = torch.device(device_type)
|
||||||
self.min_free_mem = min_free_mem
|
self.min_avail_mem = min_avail_mem
|
||||||
self.models = {}
|
self.models = {}
|
||||||
self.stack = [] # this is an LRU FIFO
|
self.stack = [] # this is an LRU FIFO
|
||||||
self.current_model = None
|
self.current_model = None
|
||||||
|
|
||||||
def get_model(self, model_name:str):
|
def get_model(self, model_name:str):
|
||||||
|
'''
|
||||||
|
Given a model named identified in models.yaml, return
|
||||||
|
the model object. If in RAM will load into GPU VRAM.
|
||||||
|
If on disk, will load from there.
|
||||||
|
'''
|
||||||
if model_name not in self.config:
|
if model_name not in self.config:
|
||||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.current_model != model_name:
|
if self.current_model != model_name:
|
||||||
@ -43,22 +56,42 @@ class ModelCache(object):
|
|||||||
|
|
||||||
if model_name in self.models:
|
if model_name in self.models:
|
||||||
requested_model = self.models[model_name]['model']
|
requested_model = self.models[model_name]['model']
|
||||||
self._model_from_cpu(requested_model)
|
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||||
|
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||||
width = self.models[model_name]['width']
|
width = self.models[model_name]['width']
|
||||||
height = self.models[model_name]['height']
|
height = self.models[model_name]['height']
|
||||||
|
hash = self.models[model_name]['hash']
|
||||||
else:
|
else:
|
||||||
self._check_memory()
|
self._check_memory()
|
||||||
requested_model, width, height = self._load_model(model_name)
|
try:
|
||||||
self.models[model_name] = {}
|
requested_model, width, height, hash = self._load_model(model_name)
|
||||||
self.models[model_name]['model'] = requested_model
|
self.models[model_name] = {}
|
||||||
self.models[model_name]['width'] = width
|
self.models[model_name]['model'] = requested_model
|
||||||
self.models[model_name]['height'] = height
|
self.models[model_name]['width'] = width
|
||||||
|
self.models[model_name]['height'] = height
|
||||||
|
self.models[model_name]['hash'] = hash
|
||||||
|
except Exception as e:
|
||||||
|
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||||
|
return {}
|
||||||
|
|
||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
return requested_model, width, height
|
return {
|
||||||
|
'model':requested_model,
|
||||||
|
'width':width,
|
||||||
|
'height':height,
|
||||||
|
'hash': hash
|
||||||
|
}
|
||||||
|
|
||||||
def list_models(self):
|
def list_models(self) -> dict:
|
||||||
|
'''
|
||||||
|
Return a dict of models in the format:
|
||||||
|
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||||
|
'description': description,
|
||||||
|
},
|
||||||
|
model_name2: { etc }
|
||||||
|
'''
|
||||||
|
result = {}
|
||||||
for name in self.config:
|
for name in self.config:
|
||||||
try:
|
try:
|
||||||
description = self.config[name].description
|
description = self.config[name].description
|
||||||
@ -70,28 +103,26 @@ class ModelCache(object):
|
|||||||
status = 'cached'
|
status = 'cached'
|
||||||
else:
|
else:
|
||||||
status = 'not loaded'
|
status = 'not loaded'
|
||||||
print(f'{name:20s} {status:>10s} {description}')
|
result[name]={}
|
||||||
|
result[name]['status']=status
|
||||||
|
result[name]['description']=description
|
||||||
|
return result
|
||||||
|
|
||||||
|
def print_models(self):
|
||||||
|
'''
|
||||||
|
Print a table of models, their descriptions, and load status
|
||||||
|
'''
|
||||||
|
models = self.list_models()
|
||||||
|
for name in models:
|
||||||
|
print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
|
||||||
|
|
||||||
def _check_memory(self):
|
def _check_memory(self):
|
||||||
free_memory = psutil.virtual_memory()[4]
|
avail_memory = psutil.virtual_memory()[1]
|
||||||
print(f'DEBUG: free memory = {free_memory}, min_mem = {self.min_free_mem}')
|
if avail_memory + AVG_MODEL_SIZE < self.min_avail_mem:
|
||||||
while free_memory + AVG_MODEL_SIZE < self.min_free_mem:
|
|
||||||
|
|
||||||
print(f'DEBUG: free memory = {free_memory}')
|
|
||||||
least_recent_model = self._pop_oldest_model()
|
least_recent_model = self._pop_oldest_model()
|
||||||
if least_recent_model is None:
|
if least_recent_model is not None:
|
||||||
return
|
del self.models[least_recent_model]
|
||||||
|
gc.collect()
|
||||||
print(f'DEBUG: clearing {least_recent_model} from cache (refcount = {getrefcount(self.models[least_recent_model]["model"])})')
|
|
||||||
del self.models[least_recent_model]['model']
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
new_free_memory = psutil.virtual_memory()[4]
|
|
||||||
if new_free_memory <= free_memory:
|
|
||||||
print(f'>> **Unable to free memory for model caching.**')
|
|
||||||
break;
|
|
||||||
free_memory = new_free_memory
|
|
||||||
|
|
||||||
|
|
||||||
def _load_model(self, model_name:str):
|
def _load_model(self, model_name:str):
|
||||||
@ -106,18 +137,20 @@ class ModelCache(object):
|
|||||||
width = mconfig.width
|
width = mconfig.width
|
||||||
height = mconfig.height
|
height = mconfig.height
|
||||||
|
|
||||||
print(f'>> Loading {model_name} weights from {weights}')
|
print(f'>> Loading {model_name} from {weights}')
|
||||||
|
|
||||||
# for usage statistics
|
# for usage statistics
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
# this does the work
|
# this does the work
|
||||||
c = OmegaConf.load(config)
|
c = OmegaConf.load(config)
|
||||||
with open(weights,'rb') as f:
|
with open(weights,'rb') as f:
|
||||||
weight_bytes = f.read()
|
weight_bytes = f.read()
|
||||||
self.model_hash = self._cached_sha256(weights,weight_bytes)
|
model_hash = self._cached_sha256(weights,weight_bytes)
|
||||||
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||||
del weight_bytes
|
del weight_bytes
|
||||||
sd = pl_sd['state_dict']
|
sd = pl_sd['state_dict']
|
||||||
@ -143,39 +176,40 @@ class ModelCache(object):
|
|||||||
'\n>> Current VRAM usage:'
|
'\n>> Current VRAM usage:'
|
||||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||||
)
|
)
|
||||||
return model, width, height
|
return model, width, height, model_hash
|
||||||
|
|
||||||
def unload_model(self, model_name:str):
|
def unload_model(self, model_name:str):
|
||||||
if model_name not in self.models:
|
if model_name not in self.models:
|
||||||
return
|
return
|
||||||
print(f'>> Unloading model {model_name}')
|
print(f'>> Caching model {model_name} in system RAM')
|
||||||
model = self.models[model_name]['model']
|
model = self.models[model_name]['model']
|
||||||
self._model_to_cpu(model)
|
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _model_to_cpu(self,model):
|
def _model_to_cpu(self,model):
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
print(f'DEBUG: moving model to cpu')
|
|
||||||
model.first_stage_model.to('cpu')
|
model.first_stage_model.to('cpu')
|
||||||
model.cond_stage_model.to('cpu')
|
model.cond_stage_model.to('cpu')
|
||||||
model.model.to('cpu')
|
model.model.to('cpu')
|
||||||
|
return model.to('cpu')
|
||||||
|
|
||||||
def _model_from_cpu(self,model):
|
def _model_from_cpu(self,model):
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
print(f'DEBUG: moving model into {self.device.type}')
|
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
model.first_stage_model.to(self.device)
|
model.first_stage_model.to(self.device)
|
||||||
model.cond_stage_model.to(self.device)
|
model.cond_stage_model.to(self.device)
|
||||||
|
return model
|
||||||
|
|
||||||
def _pop_oldest_model(self):
|
def _pop_oldest_model(self):
|
||||||
'''
|
'''
|
||||||
Remove the first element of the FIFO, which ought
|
Remove the first element of the FIFO, which ought
|
||||||
to be the least recently accessed model.
|
to be the least recently accessed model. Do not
|
||||||
|
pop the last one, because it is in active use!
|
||||||
'''
|
'''
|
||||||
if len(self.stack)>0:
|
if len(self.stack) > 1:
|
||||||
self.stack.pop(0)
|
return self.stack.pop(0)
|
||||||
|
|
||||||
def _push_newest_model(self,model_name:str):
|
def _push_newest_model(self,model_name:str):
|
||||||
'''
|
'''
|
||||||
@ -187,7 +221,6 @@ class ModelCache(object):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
self.stack.append(model_name)
|
self.stack.append(model_name)
|
||||||
print(f'DEBUG, stack={self.stack}')
|
|
||||||
|
|
||||||
def _has_cuda(self):
|
def _has_cuda(self):
|
||||||
return self.device.type == 'cuda'
|
return self.device.type == 'cuda'
|
||||||
|
@ -47,7 +47,10 @@ COMMANDS = (
|
|||||||
'--skip_normalize','-x',
|
'--skip_normalize','-x',
|
||||||
'--log_tokenization','-t',
|
'--log_tokenization','-t',
|
||||||
'--hires_fix',
|
'--hires_fix',
|
||||||
'!fix','!fetch','!history','!search','!clear',
|
'!fix','!fetch','!history','!search','!clear','!models','!switch',
|
||||||
|
)
|
||||||
|
MODEL_COMMANDS = (
|
||||||
|
'!switch',
|
||||||
)
|
)
|
||||||
IMG_PATH_COMMANDS = (
|
IMG_PATH_COMMANDS = (
|
||||||
'--outdir[=\s]',
|
'--outdir[=\s]',
|
||||||
@ -63,8 +66,9 @@ IMG_FILE_COMMANDS=(
|
|||||||
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||||
|
|
||||||
class Completer(object):
|
class Completer(object):
|
||||||
def __init__(self, options):
|
def __init__(self, options, models=[]):
|
||||||
self.options = sorted(options)
|
self.options = sorted(options)
|
||||||
|
self.models = sorted(models)
|
||||||
self.seeds = set()
|
self.seeds = set()
|
||||||
self.matches = list()
|
self.matches = list()
|
||||||
self.default_dir = None
|
self.default_dir = None
|
||||||
@ -88,6 +92,9 @@ class Completer(object):
|
|||||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||||
self.matches= self._seed_completions(text,state)
|
self.matches= self._seed_completions(text,state)
|
||||||
|
|
||||||
|
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||||
|
self.matches= self._model_completions(text,state)
|
||||||
|
|
||||||
# This is the first time for this text, so build a match list.
|
# This is the first time for this text, so build a match list.
|
||||||
elif text:
|
elif text:
|
||||||
self.matches = [
|
self.matches = [
|
||||||
@ -188,6 +195,21 @@ class Completer(object):
|
|||||||
matches.sort()
|
matches.sort()
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
def _model_completions(self, text, state):
|
||||||
|
m = re.search('(!switch\s+)(\w*)',text)
|
||||||
|
if m:
|
||||||
|
switch = m.groups()[0]
|
||||||
|
partial = m.groups()[1]
|
||||||
|
else:
|
||||||
|
switch = ''
|
||||||
|
partial = text
|
||||||
|
matches = list()
|
||||||
|
for s in self.models:
|
||||||
|
if s.startswith(partial):
|
||||||
|
matches.append(switch+s)
|
||||||
|
matches.sort()
|
||||||
|
return matches
|
||||||
|
|
||||||
def _pre_input_hook(self):
|
def _pre_input_hook(self):
|
||||||
if self.linebuffer:
|
if self.linebuffer:
|
||||||
readline.insert_text(self.linebuffer)
|
readline.insert_text(self.linebuffer)
|
||||||
@ -266,9 +288,9 @@ class DummyCompleter(Completer):
|
|||||||
def set_line(self,line):
|
def set_line(self,line):
|
||||||
print(f'# {line}')
|
print(f'# {line}')
|
||||||
|
|
||||||
def get_completer(opt:Args)->Completer:
|
def get_completer(opt:Args, models=[])->Completer:
|
||||||
if readline_available:
|
if readline_available:
|
||||||
completer = Completer(COMMANDS)
|
completer = Completer(COMMANDS,models)
|
||||||
|
|
||||||
readline.set_completer(
|
readline.set_completer(
|
||||||
completer.complete
|
completer.complete
|
||||||
|
@ -106,7 +106,7 @@ class DDPM(pl.LightningModule):
|
|||||||
], 'currently only supporting "eps" and "x0"'
|
], 'currently only supporting "eps" and "x0"'
|
||||||
self.parameterization = parameterization
|
self.parameterization = parameterization
|
||||||
print(
|
print(
|
||||||
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
f' >> {self.__class__.__name__}: Running in {self.parameterization}-prediction mode'
|
||||||
)
|
)
|
||||||
self.cond_stage_model = None
|
self.cond_stage_model = None
|
||||||
self.clip_denoised = clip_denoised
|
self.clip_denoised = clip_denoised
|
||||||
|
@ -245,7 +245,7 @@ class AttnBlock(nn.Module):
|
|||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla"):
|
def make_attn(in_channels, attn_type="vanilla"):
|
||||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
print(f" >> Making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||||
if attn_type == "vanilla":
|
if attn_type == "vanilla":
|
||||||
return AttnBlock(in_channels)
|
return AttnBlock(in_channels)
|
||||||
elif attn_type == "none":
|
elif attn_type == "none":
|
||||||
@ -521,7 +521,7 @@ class Decoder(nn.Module):
|
|||||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(
|
print(" >> Working with z of shape {} = {} dimensions.".format(
|
||||||
self.z_shape, np.prod(self.z_shape)))
|
self.z_shape, np.prod(self.z_shape)))
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
|
@ -75,7 +75,7 @@ def count_params(model, verbose=False):
|
|||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
print(
|
print(
|
||||||
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
f' >> {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||||
)
|
)
|
||||||
return total_params
|
return total_params
|
||||||
|
|
||||||
|
@ -16,8 +16,6 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
|||||||
from ldm.invoke.image_util import make_grid
|
from ldm.invoke.image_util import make_grid
|
||||||
from ldm.invoke.log import write_log
|
from ldm.invoke.log import write_log
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Initialize command-line parsers and the diffusion model"""
|
"""Initialize command-line parsers and the diffusion model"""
|
||||||
@ -33,7 +31,7 @@ def main():
|
|||||||
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
print('* Initializing, be patient...\n')
|
print('* Initializing, be patient...')
|
||||||
from ldm.generate import Generate
|
from ldm.generate import Generate
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
# these two lines prevent a horrible warning message from appearing
|
||||||
@ -42,45 +40,7 @@ def main():
|
|||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
# Loading Face Restoration and ESRGAN Modules
|
# Loading Face Restoration and ESRGAN Modules
|
||||||
try:
|
gfpgan,codeformer,esrgan = load_face_restoration(opt)
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
if opt.restore or opt.esrgan:
|
|
||||||
from ldm.invoke.restoration import Restoration
|
|
||||||
restoration = Restoration()
|
|
||||||
if opt.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
|
||||||
else:
|
|
||||||
print('>> Face restoration disabled')
|
|
||||||
if opt.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
print('>> Upscaling disabled')
|
|
||||||
else:
|
|
||||||
print('>> Face restoration and upscaling disabled')
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
|
||||||
|
|
||||||
# creating a simple text2image object with a handful of
|
|
||||||
# defaults passed on the command line.
|
|
||||||
# additional parameters will be added (or overriden) during
|
|
||||||
# the user input loop
|
|
||||||
try:
|
|
||||||
gen = Generate(
|
|
||||||
conf = opt.conf,
|
|
||||||
model = opt.model,
|
|
||||||
sampler_name = opt.sampler_name,
|
|
||||||
embedding_path = opt.embedding_path,
|
|
||||||
full_precision = opt.full_precision,
|
|
||||||
precision = opt.precision,
|
|
||||||
gfpgan=gfpgan,
|
|
||||||
codeformer=codeformer,
|
|
||||||
esrgan=esrgan,
|
|
||||||
free_gpu_mem=opt.free_gpu_mem,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, IOError, KeyError) as e:
|
|
||||||
print(f'{e}. Aborting.')
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
if not os.path.exists(opt.outdir):
|
if not os.path.exists(opt.outdir):
|
||||||
@ -100,6 +60,24 @@ def main():
|
|||||||
print(f'{e}. Aborting.')
|
print(f'{e}. Aborting.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
# creating a Generate object:
|
||||||
|
try:
|
||||||
|
gen = Generate(
|
||||||
|
conf = opt.conf,
|
||||||
|
model = opt.model,
|
||||||
|
sampler_name = opt.sampler_name,
|
||||||
|
embedding_path = opt.embedding_path,
|
||||||
|
full_precision = opt.full_precision,
|
||||||
|
precision = opt.precision,
|
||||||
|
gfpgan=gfpgan,
|
||||||
|
codeformer=codeformer,
|
||||||
|
esrgan=esrgan,
|
||||||
|
free_gpu_mem=opt.free_gpu_mem,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, IOError, KeyError) as e:
|
||||||
|
print(f'{e}. Aborting.')
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
if opt.seamless:
|
if opt.seamless:
|
||||||
print(">> changed to seamless tiling mode")
|
print(">> changed to seamless tiling mode")
|
||||||
|
|
||||||
@ -124,12 +102,12 @@ def main_loop(gen, opt, infile):
|
|||||||
done = False
|
done = False
|
||||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||||
last_results = list()
|
last_results = list()
|
||||||
model_config = OmegaConf.load(opt.conf)[opt.model]
|
model_config = OmegaConf.load(opt.conf)
|
||||||
|
|
||||||
# The readline completer reads history from the .dream_history file located in the
|
# The readline completer reads history from the .dream_history file located in the
|
||||||
# output directory specified at the time of script launch. We do not currently support
|
# output directory specified at the time of script launch. We do not currently support
|
||||||
# changing the history file midstream when the output directory is changed.
|
# changing the history file midstream when the output directory is changed.
|
||||||
completer = get_completer(opt)
|
completer = get_completer(opt, models=list(model_config.keys()))
|
||||||
output_cntr = completer.get_current_history_length()+1
|
output_cntr = completer.get_current_history_length()+1
|
||||||
|
|
||||||
# os.pathconf is not available on Windows
|
# os.pathconf is not available on Windows
|
||||||
@ -173,6 +151,16 @@ def main_loop(gen, opt, infile):
|
|||||||
command = command.replace('!fix ','',1)
|
command = command.replace('!fix ','',1)
|
||||||
operation = 'postprocess'
|
operation = 'postprocess'
|
||||||
|
|
||||||
|
elif subcommand.startswith('switch'):
|
||||||
|
model_name = command.replace('!switch ','',1)
|
||||||
|
gen.set_model(model_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif subcommand.startswith('models'):
|
||||||
|
model_name = command.replace('!models ','',1)
|
||||||
|
gen.model_cache.print_models()
|
||||||
|
continue
|
||||||
|
|
||||||
elif subcommand.startswith('fetch'):
|
elif subcommand.startswith('fetch'):
|
||||||
file_path = command.replace('!fetch ','',1)
|
file_path = command.replace('!fetch ','',1)
|
||||||
retrieve_dream_command(opt,file_path,completer)
|
retrieve_dream_command(opt,file_path,completer)
|
||||||
@ -218,9 +206,9 @@ def main_loop(gen, opt, infile):
|
|||||||
|
|
||||||
# width and height are set by model if not specified
|
# width and height are set by model if not specified
|
||||||
if not opt.width:
|
if not opt.width:
|
||||||
opt.width = model_config.width
|
opt.width = gen.width
|
||||||
if not opt.height:
|
if not opt.height:
|
||||||
opt.height = model_config.height
|
opt.height = gen.height
|
||||||
|
|
||||||
# retrieve previous value of init image if requested
|
# retrieve previous value of init image if requested
|
||||||
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
|
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
|
||||||
@ -509,6 +497,7 @@ def get_next_command(infile=None) -> str: # command string
|
|||||||
|
|
||||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||||
print('\n* --web was specified, starting web server...')
|
print('\n* --web was specified, starting web server...')
|
||||||
|
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||||
# Change working directory to the stable-diffusion directory
|
# Change working directory to the stable-diffusion directory
|
||||||
os.chdir(
|
os.chdir(
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
@ -547,6 +536,28 @@ def split_variations(variations_string) -> list:
|
|||||||
else:
|
else:
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
def load_face_restoration(opt):
|
||||||
|
try:
|
||||||
|
gfpgan, codeformer, esrgan = None, None, None
|
||||||
|
if opt.restore or opt.esrgan:
|
||||||
|
from ldm.invoke.restoration import Restoration
|
||||||
|
restoration = Restoration()
|
||||||
|
if opt.restore:
|
||||||
|
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||||
|
else:
|
||||||
|
print('>> Face restoration disabled')
|
||||||
|
if opt.esrgan:
|
||||||
|
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||||
|
else:
|
||||||
|
print('>> Upscaling disabled')
|
||||||
|
else:
|
||||||
|
print('>> Face restoration and upscaling disabled')
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||||
|
return gfpgan,codeformer,esrgan
|
||||||
|
|
||||||
|
|
||||||
def retrieve_dream_command(opt,file_path,completer):
|
def retrieve_dream_command(opt,file_path,completer):
|
||||||
'''
|
'''
|
||||||
Given a full or partial path to a previously-generated image file,
|
Given a full or partial path to a previously-generated image file,
|
||||||
|
Loading…
Reference in New Issue
Block a user