mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enable fast switching between models in invoke.py
This PR enables two new commands in the invoke.py script:

  !models          -- list the available models and their cache status
  !switch <model>  -- switch to the indicated model

Example:

  invoke> !models
  laion400m             not loaded  Latent Diffusion LAION400M model
  stable-diffusion-1.4  active      Stable Diffusion inference model version 1.4
  waifu-1.3             cached      Waifu anime model version 1.3
  invoke> !switch waifu-1.3
  >> Caching model stable-diffusion-1.4 in system RAM
  >> Retrieving model waifu-1.3 from system RAM cache

The names and descriptions of the models are taken from `config/models.yaml`. A future enhancement to `model_cache.py` will be to enable new model stanzas to be added to the file programmatically. This will be useful for the WebGUI.

More details:

- Uses the fast-switching algorithm described in PR #948.
- Models are selected using their configuration stanza name given in models.yaml.
- To avoid filling up CPU RAM with cached models, this PR implements an LRU cache that monitors available CPU RAM (a sketch of the idea appears below).
- The caching code allows the minimum amount of available RAM to be adjusted, but invoke.py does not currently have a command-line argument to set it. The minimum free RAM is arbitrarily set to 2 GB.
- Adds an optional description field to configs/models.yaml.

Unrelated fixes:

- Added ">>" to the CompViz model loading messages to make the user experience more consistent.
- When generating an image larger than the defaults, invoke.py now warns about possible VRAM exhaustion only the first time.
- Fixed a bug that caused the help message to be printed twice. This involved moving the import of the web backend into the section where it is called.

Co-authored-by: @ArDiouscuros
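The RAM-aware LRU cache is the core of the fast-switching scheme. The following is a minimal sketch of the idea only, not the actual model_cache.py implementation: the class and method names are illustrative, and psutil is assumed for the free-RAM check.

    from collections import OrderedDict

    import psutil

    GIG = 2 ** 30

    class RAMAwareLRUCache:
        """Illustrative LRU model cache: evicts the least-recently-used
        entries whenever available system RAM falls below a minimum."""

        def __init__(self, min_avail_ram=2 * GIG):
            self.min_avail_ram = min_avail_ram
            self.cache = OrderedDict()  # model_name -> model data

        def get(self, model_name):
            """Return cached data for model_name, or None if not cached."""
            if model_name in self.cache:
                self.cache.move_to_end(model_name)  # mark most recently used
                return self.cache[model_name]
            return None

        def put(self, model_name, model_data):
            """Cache model_data, evicting old entries while RAM is low."""
            self.cache[model_name] = model_data
            self.cache.move_to_end(model_name)
            while len(self.cache) > 1 and psutil.virtual_memory().available < self.min_avail_ram:
                evicted, _ = self.cache.popitem(last=False)  # drop the LRU entry
                print(f'>> Evicting model {evicted} from RAM cache')

Keying eviction on measured free RAM rather than a fixed entry count lets the cache hold as many models as the machine can afford.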
This commit is contained in:
152	ldm/generate.py
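For reference, here is a hypothetical models.yaml stanza consistent with the fields the diff reads (config, weights, width, height) plus the new optional description field; the paths and values are illustrative, not copied from the repository. It can be loaded with OmegaConf, as generate.py does:

    from omegaconf import OmegaConf

    # Hypothetical models.yaml stanza; paths and values are illustrative.
    # 'description' is the new optional field added by this commit.
    yaml_text = """
    stable-diffusion-1.4:
      config: configs/stable-diffusion/v1-inference.yaml
      weights: models/ldm/stable-diffusion-v1/model.ckpt
      width: 512
      height: 512
      description: Stable Diffusion inference model version 1.4
    """

    mconfig = OmegaConf.create(yaml_text)
    stanza = mconfig['stable-diffusion-1.4']
    print(stanza.weights)      # models/ldm/stable-diffusion-v1/model.ckpt
    print(stanza.description)  # Stable Diffusion inference model version 1.4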
@@ -33,6 +33,7 @@ from ldm.invoke.args import metadata_from_png
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.conditioning import get_uc_and_c
+from ldm.invoke.model_cache import ModelCache
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -123,12 +124,11 @@ class Generate:
             esrgan=None,
             free_gpu_mem=False,
     ):
-        models = OmegaConf.load(conf)
-        mconfig = models[model]
-        self.weights = mconfig.weights if weights is None else weights
-        self.config = mconfig.config if config is None else config
-        self.height = mconfig.height
-        self.width = mconfig.width
+        mconfig = OmegaConf.load(conf)
+        self.model_name = model
+        self.height = None
+        self.width = None
+        self.model_cache = None
         self.iterations = 1
         self.steps = 50
         self.cfg_scale = 7.5
@@ -139,6 +139,7 @@ class Generate:
         self.seamless = False
         self.embedding_path = embedding_path
         self.model = None     # empty for now
+        self.model_hash = None
         self.sampler = None
         self.device = None
         self.session_peakmem = None
@@ -149,6 +150,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.size_matters = True  # used to warn once about large image sizes and VRAM
 
         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -164,6 +166,9 @@ class Generate:
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
 
+        # model caching system for fast switching
+        self.model_cache = ModelCache(mconfig, self.device, self.precision)
+
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
         transformers.logging.set_verbosity_error()
@@ -294,7 +299,12 @@ class Generate:
         with_variations = [] if with_variations is None else with_variations
 
         # will instantiate the model or return it from cache
-        model = self.load_model()
+        model = self.set_model(self.model_name)
+
+        # self.width and self.height are set by set_model()
+        # to the width and height of the image training set
+        width = width or self.width
+        height = height or self.height
 
         for m in model.modules():
             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
@@ -584,8 +594,9 @@ class Generate:
         # this returns a torch tensor
         init_mask = self._create_init_mask(image, width, height, fit=fit)
 
-        if (image.width * image.height) > (self.width * self.height):
+        if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
+            self.size_matters = False
 
         init_image = self._create_init_image(image, width, height, fit=fit)   # this returns a torch tensor
 
@@ -635,29 +646,47 @@ class Generate:
         return self.generators['inpaint']
 
     def load_model(self):
-        """Load and initialize the model from configuration variables passed at object creation time"""
-        if self.model is None:
-            seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
-            try:
-                model = self._load_model_from_config(self.config, self.weights)
-                if self.embedding_path is not None:
-                    model.embedding_manager.load(
-                        self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
-                    )
-                self.model = model.to(self.device)
-                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
-                self.model.cond_stage_model.device = self.device
-            except AttributeError as e:
-                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
-                raise SystemExit from e
+        '''
+        preload model identified in self.model_name
+        '''
+        self.set_model(self.model_name)
 
-            self._set_sampler()
+    def set_model(self, model_name):
+        """
+        Given the name of a model defined in models.yaml, will load and initialize it
+        and return the model object. Previously-used models will be cached.
+        """
+        if self.model_name == model_name and self.model is not None:
+            return self.model
 
-            for m in self.model.modules():
-                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                    m._orig_padding_mode = m.padding_mode
+        model_data = self.model_cache.get_model(model_name)
+        if model_data is None or len(model_data) == 0:
+            print('** Model switch failed **')
+            return self.model
+
+        self.model = model_data['model']
+        self.width = model_data['width']
+        self.height = model_data['height']
+        self.model_hash = model_data['hash']
+
+        # uncache generators so they pick up new models
+        self.generators = {}
+
+        seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
+        if self.embedding_path is not None:
+            self.model.embedding_manager.load(
+                self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
+            )
+
+        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
+        self.model.cond_stage_model.device = self.device
+        self._set_sampler()
+
+        for m in self.model.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                m._orig_padding_mode = m.padding_mode
 
+        self.model_name = model_name
         return self.model
 
     def correct_colors(self,
@@ -761,53 +790,6 @@ class Generate:
 
         print(msg)
 
-    # Be warned: config is the path to the model config file, not the invoke conf file!
-    # Also note that we can get config and weights from self, so why do we need to
-    # pass them as args?
-    def _load_model_from_config(self, config, weights):
-        print(f'>> Loading model from {weights}')
-
-        # for usage statistics
-        device_type = choose_torch_device()
-        if device_type == 'cuda':
-            torch.cuda.reset_peak_memory_stats()
-        tic = time.time()
-
-        # this does the work
-        c = OmegaConf.load(config)
-        with open(weights, 'rb') as f:
-            weight_bytes = f.read()
-        self.model_hash = self._cached_sha256(weights, weight_bytes)
-        pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
-        del weight_bytes
-        sd = pl_sd['state_dict']
-        model = instantiate_from_config(c.model)
-        m, u = model.load_state_dict(sd, strict=False)
-
-        if self.precision == 'float16':
-            print('>> Using faster float16 precision')
-            model.to(torch.float16)
-        else:
-            print('>> Using more accurate float32 precision')
-
-        model.to(self.device)
-        model.eval()
-
-        # usage statistics
-        toc = time.time()
-        print(
-            '>> Model loaded in', '%4.2fs' % (toc - tic)
-        )
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-
-        return model
-
     def _load_img(self, img, width, height) -> Image:
         if isinstance(img, Image.Image):
             image = img
@@ -951,26 +933,6 @@ class Generate:
     def _has_cuda(self):
         return self.device.type == 'cuda'
 
-    def _cached_sha256(self, path, data):
-        dirname = os.path.dirname(path)
-        basename = os.path.basename(path)
-        base, _ = os.path.splitext(basename)
-        hashpath = os.path.join(dirname, base + '.sha256')
-        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
-            with open(hashpath) as f:
-                hash = f.read()
-            return hash
-        print('>> Calculating sha256 hash of weights file')
-        tic = time.time()
-        sha = hashlib.sha256()
-        sha.update(data)
-        hash = sha.hexdigest()
-        toc = time.time()
-        print(f'>> sha256 = {hash}', '(%4.2fs)' % (toc - tic))
-        with open(hashpath, 'w') as f:
-            f.write(hash)
-        return hash
-
     def write_intermediate_images(self, modulus, path):
         counter = -1
         if not os.path.exists(path):
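Putting the pieces together, a hypothetical usage sketch of the fast-switching API introduced above; the constructor arguments are assumptions inferred from the __init__ changes in this diff and may not match the full signature in ldm/generate.py.

    from ldm.generate import Generate

    # Hypothetical usage; 'model' and 'conf' arguments are assumptions
    # based on the diff, and the model names come from the commit message.
    gr = Generate(model='stable-diffusion-1.4', conf='configs/models.yaml')

    gr.load_model()  # preloads the model named at construction time

    # Switching caches the outgoing model in system RAM and retrieves the
    # incoming one from cache when possible; set_model() also resets
    # self.width/self.height to the new model's training dimensions.
    model = gr.set_model('waifu-1.3')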