Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
* add whole <style token> to vocab for concept library embeddings
* add ability to load multiple concept .bin files
* make --log_tokenization respect custom tokens
* start working on concept downloading system
* preliminary support for dynamic loading and merging of multiple embedded models

  - The embedding_manager is now enhanced with ldm.invoke.concepts_lib,
    which handles dynamic downloading and caching of embedded models from
    the Hugging Face concepts library
    (https://huggingface.co/sd-concepts-library).
  - Downloading of an embedded model is triggered by the presence of one or
    more <concept> tags in the prompt.
  - Once the embedded model is downloaded, its trigger phrase will be loaded
    into the embedding manager, and the prompt's <concept> tag will be
    replaced with the <trigger_phrase>.
  - The downloaded model stays on disk for fast loading later.
  - The CLI autocomplete will complete partial <concept> tags for you. Type
    a '<' and hit tab to get all ~700 concepts.

  BUGS AND LIMITATIONS:

  - MODEL NAME VS TRIGGER PHRASE

    You must use the name of the concept embed model from the SD library,
    not the trigger phrase itself. Usually these are the same, but not
    always. For example, the model named "hoi4-leaders" corresponds to the
    trigger "<HOI4-Leader>".

    One reason for this design choice is that there is no apparent
    constraint on the uniqueness of the trigger phrases, so one trigger
    phrase may map onto multiple models; we therefore use the model name
    instead. The second reason is that there is no way I know of to search
    Hugging Face for models with certain trigger phrases, so we would have
    to download all ~700 models to index the phrases.

    The problem this presents is that it may confuse users who want to
    reuse prompts from distributions that use the trigger phrase directly.
    Usually this will work, but not always.

  - WON'T WORK ON A FIREWALLED SYSTEM

    If the host running InvokeAI has no internet connection, it can't
    download the concept libraries. I will add a script that allows users
    to preload a list of concept models.

  - BUG IN PROMPT REPLACEMENT WHEN MODEL NOT FOUND

    There's a small bug that occurs when the user provides an invalid model
    name: the <concept> tag gets replaced with <None> in the prompt.

* fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes
* simplify replacement logic and remove cuda assumption
* download list of concepts from hugging face
* remove misleading customization of '*' placeholder

  The existing code as-is did not do anything, and it is unclear what it
  was supposed to do. The obvious alternative -- using
  'placeholder_strings' instead of 'placeholder_tokens' to match
  model.params.personalization_config.params.placeholder_strings -- caused
  a crash. I think this is because the passed string also needed to be
  handed over on init of the PersonalizedBase as the 'placeholder_token'
  argument. This is weird config-dict magic and I don't want to touch it.
  Put a breakpoint at personalized.py line 116 (top of
  PersonalizedBase.__init__) if you want to have a crack at it yourself.

* address all the issues raised by damian0815 in review of PR #1526
* actually resize the token_embeddings
* multiple improvements to the concept loader based on code reviews

  1. Activated the --embedding_directory option (alias --embedding_path)
     to load a single embedding or an entire directory of embeddings at
     startup time.
  2. Can turn off automatic loading of embeddings using --no-embeddings.
  3. Embedding checkpoints are scanned with the pickle scanner.
  4. More informative error messages when a concept can't be loaded due
     either to a 404 not found error or a network error.

* autocomplete terms end with ">" now
* fix startup error and network unreachable

  1. If the .invokeai file does not contain the --root and --outdir
     options, invoke.py will now fix it.
  2. Catch and handle network problems when downloading Hugging Face
     textual inversion concepts.

* fix misformatted error string

Co-authored-by: Damian Stewart <d@damianstewart.com>
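To make the prompt-side mechanics concrete, here is a minimal illustrative sketch of the <concept> tag substitution described above. The download_concept helper and its name-to-trigger table are hypothetical stand-ins, not the actual ldm.invoke.concepts_lib API:

import re
from typing import Optional

# Hypothetical stand-in for ldm.invoke.concepts_lib: the real code downloads
# the .bin embedding from https://huggingface.co/sd-concepts-library, caches
# it on disk, and loads its trigger phrase into the embedding manager.
def download_concept(model_name: str) -> Optional[str]:
    known = {'hoi4-leaders': '<HOI4-Leader>'}  # model name -> trigger phrase
    return known.get(model_name)               # None when the model is unknown

def expand_concept_tags(prompt: str) -> str:
    '''Replace each <concept> tag with the trigger phrase of its embedding.'''
    def substitute(match) -> str:
        trigger = download_concept(match.group(1))
        # Mirrors the known bug described above: an invalid model name
        # leaves '<None>' in the prompt.
        return trigger if trigger is not None else '<None>'
    return re.sub(r'<([\w-]+)>', substitute, prompt)

print(expand_concept_tags('portrait of <hoi4-leaders>'))
# -> portrait of <HOI4-Leader>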
416 lines
15 KiB
Python
'''
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
'''

import contextlib
import gc
import hashlib
import io
import os
import sys
import textwrap
import time
import traceback
from typing import Union

import psutil
import torch
import transformers
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from picklescan.scanner import scan_file_path

from ldm.invoke.globals import Globals
from ldm.util import instantiate_from_config, ask_user

DEFAULT_MAX_MODELS = 2

class ModelCache(object):
    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
        '''
        Initialize with the path to the models.yaml config file,
        the torch device type, and precision. The optional
        max_loaded_models argument specifies how many models may be
        held in the RAM cache at once; when the limit is reached,
        the least recently used model is purged. Default is 2.
        '''
        # prevent nasty-looking CLIP log message
        transformers.logging.set_verbosity_error()
        self.config = config
        self.precision = precision
        self.device = torch.device(device_type)
        self.max_loaded_models = max_loaded_models
        self.models = {}
        self.stack = []  # this is an LRU FIFO
        self.current_model = None

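    # Illustrative usage sketch (the model name "stable-diffusion-1.5" and the
    # config path below are hypothetical examples, not fixed API values):
    #
    #   config = OmegaConf.load('configs/models.yaml')
    #   cache = ModelCache(config, device_type='cuda', precision='float16')
    #   info = cache.get_model('stable-diffusion-1.5')
    #   model, width, height = info['model'], info['width'], info['height']
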
    def valid_model(self, model_name:str) -> bool:
        '''
        Given a model name, returns True if it is a valid
        identifier.
        '''
        return model_name in self.config

    def get_model(self, model_name:str):
        '''
        Given the name of a model identified in models.yaml, return the
        model object. If cached in RAM, the model is moved into GPU VRAM;
        if on disk, it is loaded from there.
        '''
        if not self.valid_model(model_name):
            print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
            return self.current_model

        if self.current_model != model_name:
            if model_name not in self.models:  # make room for a new one
                self._make_cache_room()
            self.offload_model(self.current_model)

        if model_name in self.models:
            requested_model = self.models[model_name]['model']
            print(f'>> Retrieving model {model_name} from system RAM cache')
            self.models[model_name]['model'] = self._model_from_cpu(requested_model)
            width = self.models[model_name]['width']
            height = self.models[model_name]['height']
            hash = self.models[model_name]['hash']

        else:  # not in the RAM cache, so load it fresh from disk
            try:
                requested_model, width, height, hash = self._load_model(model_name)
                self.models[model_name] = {
                    'model': requested_model,
                    'width': width,
                    'height': height,
                    'hash': hash,
                }

            except Exception as e:
                print(f'** model {model_name} could not be loaded: {str(e)}')
                print(traceback.format_exc())
                assert self.current_model, '** FATAL: no current model to restore to'
                print(f'** restoring {self.current_model}')
                self.get_model(self.current_model)
                return

        self.current_model = model_name
        self._push_newest_model(model_name)
        return {
            'model': requested_model,
            'width': width,
            'height': height,
            'hash': hash,
        }

    def default_model(self) -> str:
        '''
        Returns the name of the default model, or None
        if none is defined.
        '''
        for model_name in self.config:
            if self.config[model_name].get('default'):
                return model_name

    def set_default_model(self, model_name:str) -> None:
        '''
        Set the default model. The change will not take
        effect until you call model_cache.commit().
        '''
        assert model_name in self.config, f"unknown model '{model_name}'"

        config = self.config
        for model in config:
            config[model].pop('default', None)
        config[model_name]['default'] = True

    def list_models(self) -> dict:
        '''
        Return a dict of models in the format:
        { model_name1: {'status': ('active'|'cached'|'not loaded'),
                        'description': description,
                       },
          model_name2: { etc }
        }
        '''
        result = dict()
        for name in self.config:
            try:
                description = self.config[name].description
            except ConfigAttributeError:
                description = '<no description>'

            if self.current_model == name:
                status = 'active'
            elif name in self.models:
                status = 'cached'
            else:
                status = 'not loaded'

            result[name] = {
                'status': status,
                'description': description,
            }
        return result

    def print_models(self) -> None:
        '''
        Print a table of models, their descriptions, and load status.
        '''
        models = self.list_models()
        for name in models:
            line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
            if models[name]['status'] == 'active':
                line = f'\033[1m{line}\033[0m'
            print(line)

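    # Sample print_models() output (illustrative model names and
    # descriptions; the active row is rendered in bold via the ANSI
    # escapes above):
    #
    #   stable-diffusion-1.5          active Stable Diffusion v1.5
    #   waifu-diffusion           not loaded Waifu Diffusion v1.3
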
    def del_model(self, model_name:str) -> None:
        '''
        Delete the named model.
        '''
        omega = self.config
        del omega[model_name]
        if model_name in self.stack:
            self.stack.remove(model_name)

    def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
        '''
        Update the named model with a dictionary of attributes. Will fail with an
        assertion error if the name already exists. Pass clobber=True to overwrite.
        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or the
        model name is missing.
        '''
        omega = self.config
        for field in ('description','weights','height','width','config'):
            assert field in model_attributes, f'required field {field} is missing'
        assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'

        config = omega[model_name] if model_name in omega else {}
        for field in model_attributes:
            config[field] = model_attributes[field]

        omega[model_name] = config
        if clobber:
            self._invalidate_cached_model(model_name)

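    # Example call (illustrative values; 'description', 'weights', 'height',
    # 'width', and 'config' are the required fields asserted above):
    #
    #   cache.add_model('my-finetune', {
    #       'description': 'a fine-tuned model',
    #       'weights': 'models/ldm/my-finetune.ckpt',
    #       'config': 'configs/stable-diffusion/v1-inference.yaml',
    #       'width': 512,
    #       'height': 512,
    #   }, clobber=False)
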
    def _load_model(self, model_name:str):
        '''Load and initialize the model from configuration variables passed at object creation time.'''
        if model_name not in self.config:
            print(f'"{model_name}" is not a known model name. Please check your models.yaml file')

        mconfig = self.config[model_name]
        config = mconfig.config
        weights = mconfig.weights
        vae = mconfig.get('vae')
        width = mconfig.width
        height = mconfig.height

        if not os.path.isabs(weights):
            weights = os.path.normpath(os.path.join(Globals.root, weights))
        # scan model
        self.scan_model(model_name, weights)

        print(f'>> Loading {model_name} from {weights}')

        # for usage statistics
        if self._has_cuda():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()

        tic = time.time()

        # this does the work
        if not os.path.isabs(config):
            config = os.path.join(Globals.root, config)
        omega_config = OmegaConf.load(config)
        with open(weights, 'rb') as f:
            weight_bytes = f.read()
        model_hash = self._cached_sha256(weights, weight_bytes)
        sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
        del weight_bytes
        sd = sd['state_dict']
        model = instantiate_from_config(omega_config.model)
        model.load_state_dict(sd, strict=False)

        if self.precision == 'float16':
            print('   | Using faster float16 precision')
            model.to(torch.float16)
        else:
            print('   | Using more accurate float32 precision')

        # look for and load a matching VAE file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
        if vae:
            if not os.path.isabs(vae):
                vae = os.path.normpath(os.path.join(Globals.root, vae))
            if os.path.exists(vae):
                print(f'   | Loading VAE weights from: {vae}')
                vae_ckpt = torch.load(vae, map_location='cpu')
                vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != 'loss'}
                model.first_stage_model.load_state_dict(vae_dict, strict=False)
            else:
                print(f'   | VAE file {vae} not found. Skipping.')

        model.to(self.device)
        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
        model.cond_stage_model.device = self.device

        model.eval()

        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module._orig_padding_mode = module.padding_mode

        # usage statistics
        toc = time.time()
        print('>> Model loaded in', '%4.2fs' % (toc - tic))

        if self._has_cuda():
            print(
                '>> Max VRAM used to load the model:',
                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
                '\n>> Current VRAM usage:',
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )

        return model, width, height, model_hash

    def offload_model(self, model_name:str) -> None:
        '''
        Offload the indicated model to CPU, freeing its GPU VRAM
        while keeping the weights cached in system RAM.
        '''
        if model_name not in self.models:
            return

        print(f'>> Offloading {model_name} to CPU')
        model = self.models[model_name]['model']
        self.models[model_name]['model'] = self._model_to_cpu(model)

        gc.collect()
        if self._has_cuda():
            torch.cuda.empty_cache()

    def scan_model(self, model_name, checkpoint):
        # scan model
        print(f'>> Scanning Model: {model_name}')
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            if scan_result.infected_files == 1:
                print(f'\n### Issues Found In Model: {scan_result.issues_count}')
                print('### WARNING: The model you are trying to load seems to be infected.')
                print('### For your safety, InvokeAI will not load this model.')
                print('### Please use checkpoints from trusted sources.')
                print('### Exiting InvokeAI')
                sys.exit()
            else:
                print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
                model_safe_check_fail = ask_user('Do you want to continue loading the model?', ['y', 'n'])
                if model_safe_check_fail.lower() != 'y':
                    print('### Exiting InvokeAI')
                    sys.exit()
        else:
            print('>> Model Scanned. OK!!')

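    # Note: scan_file_path() statically inspects the pickle opcode stream
    # inside the checkpoint for dangerous imports without unpickling it
    # (see https://github.com/mmaitre314/picklescan), so the scan itself
    # cannot execute embedded code.
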
    def _make_cache_room(self) -> None:
        num_loaded_models = len(self.models)
        if num_loaded_models >= self.max_loaded_models:
            least_recent_model = self._pop_oldest_model()
            if least_recent_model is not None:
                print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
                del self.models[least_recent_model]
                gc.collect()

    def print_vram_usage(self) -> None:
        if self._has_cuda():
            print('>> Current VRAM usage: ', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9))

    def commit(self, config_file_path:str) -> None:
        '''
        Write the current configuration out to the indicated file.
        '''
        yaml_str = OmegaConf.to_yaml(self.config)
        tmpfile = os.path.join(os.path.dirname(config_file_path), 'new_config.tmp')
        with open(tmpfile, 'w') as outfile:
            outfile.write(self.preamble())
            outfile.write(yaml_str)
        os.replace(tmpfile, config_file_path)

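    # Writing to a temporary file and then os.replace()-ing it over the
    # target makes the update atomic: os.replace() either fully succeeds or
    # leaves the original file in place, so a crash mid-write cannot leave
    # models.yaml truncated.
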
    def preamble(self) -> str:
        '''
        Returns the preamble for the config file.
        '''
        return textwrap.dedent('''\
            # This file describes the alternative machine learning models
            # available to the InvokeAI script.
            #
            # To add a new model, follow the examples below. Each
            # model requires a model config file, a weights file,
            # and the width and height of the images it
            # was trained on.
        ''')

    def _invalidate_cached_model(self, model_name:str) -> None:
        self.offload_model(model_name)
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.models.pop(model_name, None)

    def _model_to_cpu(self, model):
        if self.device != 'cpu':
            model.cond_stage_model.device = 'cpu'
            model.first_stage_model.to('cpu')
            model.cond_stage_model.to('cpu')
            model.model.to('cpu')
            return model.to('cpu')
        else:
            return model

    def _model_from_cpu(self, model):
        if self.device != 'cpu':
            model.to(self.device)
            model.first_stage_model.to(self.device)
            model.cond_stage_model.to(self.device)
            model.cond_stage_model.device = self.device
        return model

    def _pop_oldest_model(self):
        '''
        Remove the first element of the FIFO, which ought
        to be the least recently accessed model. Do not
        pop the last one, because it is in active use!
        '''
        if len(self.stack) > 0:
            return self.stack.pop(0)

    def _push_newest_model(self, model_name:str) -> None:
        '''
        Maintain a simple FIFO. First element is always the
        least recent, and last element is always the most recent.
        '''
        with contextlib.suppress(ValueError):
            self.stack.remove(model_name)
        self.stack.append(model_name)

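    # LRU bookkeeping sketch (illustrative): after
    #   _push_newest_model('a'); _push_newest_model('b'); _push_newest_model('a')
    # the stack is ['b', 'a'], so _pop_oldest_model() returns 'b' while the
    # most recently used model 'a' stays resident.
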
    def _has_cuda(self) -> bool:
        return self.device.type == 'cuda'

    def _cached_sha256(self, path, data) -> str:
        dirname = os.path.dirname(path)
        basename = os.path.basename(path)
        base, _ = os.path.splitext(basename)
        hashpath = os.path.join(dirname, base + '.sha256')

        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
            with open(hashpath) as f:
                hash = f.read()
            return hash

        print('>> Calculating sha256 hash of weights file')
        tic = time.time()
        sha = hashlib.sha256()
        sha.update(data)
        hash = sha.hexdigest()
        toc = time.time()
        print(f'>> sha256 = {hash}', '(%4.2fs)' % (toc - tic))

        with open(hashpath, 'w') as f:
            f.write(hash)
        return hash
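
    # Example (illustrative path): for weights at models/ldm/sd-v1-4.ckpt the
    # sidecar is models/ldm/sd-v1-4.sha256; the hash is recomputed only when
    # the weights file is newer than the sidecar.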