mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into patch-1
This commit is contained in:
commit
a6e7aa8f97
@ -6,15 +6,16 @@
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
|
||||
laion400m:
|
||||
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
|
||||
weights: models/ldm/text2img-large/model.ckpt
|
||||
description: Latent Diffusion LAION400M model
|
||||
width: 256
|
||||
height: 256
|
||||
stable-diffusion-1.4:
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
description: Stable Diffusion inference model version 1.4
|
||||
width: 512
|
||||
height: 512
|
||||
stable-diffusion-1.5:
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||
description: Stable Diffusion inference model version 1.5
|
||||
width: 512
|
||||
height: 512
|
||||
|
@ -8,7 +8,7 @@ hide:
|
||||
|
||||
## **Interactive Command Line Interface**
|
||||
|
||||
The `invoke.py` script, located in `scripts/dream.py`, provides an interactive
|
||||
The `invoke.py` script, located in `scripts/`, provides an interactive
|
||||
interface to image generation similar to the "invoke mothership" bot that Stable
|
||||
AI provided on its Discord server.
|
||||
|
||||
|
@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like)
|
||||
torch.bernoulli = fix_func(torch.bernoulli)
|
||||
torch.multinomial = fix_func(torch.multinomial)
|
||||
|
||||
# this is fallback model in case no default is defined
|
||||
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
Example Usage:
|
||||
@ -129,7 +132,7 @@ class Generate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model = 'stable-diffusion-1.4',
|
||||
model = None,
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
@ -145,7 +148,6 @@ class Generate:
|
||||
free_gpu_mem=False,
|
||||
):
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.model_name = model
|
||||
self.height = None
|
||||
self.width = None
|
||||
self.model_cache = None
|
||||
@ -192,6 +194,7 @@ class Generate:
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
@ -552,16 +555,19 @@ class Generate:
|
||||
from ldm.invoke.restoration.outcrop import Outcrop
|
||||
extend_instructions = {}
|
||||
for direction,pixels in _pairwise(opt.outcrop):
|
||||
extend_instructions[direction]=int(pixels)
|
||||
|
||||
restorer = Outcrop(image,self,)
|
||||
return restorer.process (
|
||||
extend_instructions,
|
||||
opt = opt,
|
||||
orig_opt = args,
|
||||
image_callback = callback,
|
||||
prefix = prefix,
|
||||
)
|
||||
try:
|
||||
extend_instructions[direction]=int(pixels)
|
||||
except ValueError:
|
||||
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
||||
if len(extend_instructions)>0:
|
||||
restorer = Outcrop(image,self,)
|
||||
return restorer.process (
|
||||
extend_instructions,
|
||||
opt = opt,
|
||||
orig_opt = args,
|
||||
image_callback = callback,
|
||||
prefix = prefix,
|
||||
)
|
||||
|
||||
elif tool == 'embiggen':
|
||||
# fetch the metadata from the image
|
||||
@ -697,8 +703,7 @@ class Generate:
|
||||
|
||||
model_data = self.model_cache.get_model(model_name)
|
||||
if model_data is None or len(model_data) == 0:
|
||||
print(f'** Model switch failed **')
|
||||
return self.model
|
||||
return None
|
||||
|
||||
self.model = model_data['model']
|
||||
self.width = model_data['width']
|
||||
|
@ -366,17 +366,16 @@ class Args(object):
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
model_group.add_argument(
|
||||
'--conf',
|
||||
'--config',
|
||||
'-c',
|
||||
'-conf',
|
||||
'-config',
|
||||
dest='conf',
|
||||
default='./configs/models.yaml',
|
||||
help='Path to configuration file for alternate models.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--model',
|
||||
default='stable-diffusion-1.4',
|
||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
@ -529,7 +528,7 @@ class Args(object):
|
||||
formatter_class=ArgFormatter,
|
||||
description=
|
||||
"""
|
||||
*Image generation:*
|
||||
*Image generation*
|
||||
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||
|
||||
*postprocessing*
|
||||
@ -544,6 +543,13 @@ class Args(object):
|
||||
!history lists all the commands issued during the current session.
|
||||
|
||||
!NN retrieves the NNth command from the history
|
||||
|
||||
*Model manipulation*
|
||||
!models -- list models in configs/models.yaml
|
||||
!switch <model_name> -- switch to model named <model_name>
|
||||
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
||||
!edit_model <model_name> -- edit a model's description
|
||||
!del_model <model_name> -- delete a model
|
||||
"""
|
||||
)
|
||||
render_group = parser.add_argument_group('General rendering')
|
||||
@ -967,17 +973,17 @@ def sha256(path):
|
||||
return sha.hexdigest()
|
||||
|
||||
def legacy_metadata_load(meta,pathname) -> Args:
|
||||
opt = Args()
|
||||
if 'Dream' in meta and len(meta['Dream']) > 0:
|
||||
dream_prompt = meta['Dream']
|
||||
opt = Args()
|
||||
opt.parse_cmd(dream_prompt)
|
||||
return opt
|
||||
else: # if nothing else, we can get the seed
|
||||
match = re.search('\d+\.(\d+)',pathname)
|
||||
if match:
|
||||
seed = match.groups()[0]
|
||||
opt = Args()
|
||||
opt.seed = seed
|
||||
return opt
|
||||
return None
|
||||
else:
|
||||
opt.prompt = ''
|
||||
opt.seed = 0
|
||||
return opt
|
||||
|
||||
|
@ -13,6 +13,7 @@ import gc
|
||||
import hashlib
|
||||
import psutil
|
||||
import transformers
|
||||
import os
|
||||
from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
@ -73,7 +74,8 @@ class ModelCache(object):
|
||||
except Exception as e:
|
||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||
print(f'** restoring {self.current_model}')
|
||||
return self.get_model(self.current_model)
|
||||
self.get_model(self.current_model)
|
||||
return None
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
@ -84,6 +86,26 @@ class ModelCache(object):
|
||||
'hash': hash
|
||||
}
|
||||
|
||||
def default_model(self) -> str:
|
||||
'''
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
'''
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get('default',False):
|
||||
return model_name
|
||||
return None
|
||||
|
||||
def set_default_model(self,model_name:str):
|
||||
'''
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_cache.commit()
|
||||
'''
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
for model in self.models:
|
||||
self.models[model].pop('default',None)
|
||||
self.models[model_name]['default'] = True
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
Return a dict of models in the format:
|
||||
@ -121,12 +143,23 @@ class ModelCache(object):
|
||||
else:
|
||||
print(line)
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
||||
def del_model(self, model_name:str) ->bool:
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
omega = self.config
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
return True
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||
'''
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and a YAML
|
||||
string will be returned.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
'''
|
||||
omega = self.config
|
||||
# check that all the required fields are present
|
||||
@ -139,7 +172,9 @@ class ModelCache(object):
|
||||
config[field] = model_attributes[field]
|
||||
|
||||
omega[model_name] = config
|
||||
return OmegaConf.to_yaml(omega)
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
return True
|
||||
|
||||
def _check_memory(self):
|
||||
avail_memory = psutil.virtual_memory()[1]
|
||||
@ -159,6 +194,7 @@ class ModelCache(object):
|
||||
mconfig = self.config[model_name]
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get('vae',None)
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
@ -188,9 +224,17 @@ class ModelCache(object):
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae and os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
|
||||
model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
model.eval()
|
||||
|
||||
for m in model.modules():
|
||||
@ -219,6 +263,36 @@ class ModelCache(object):
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def commit(self,config_file_path:str):
|
||||
'''
|
||||
Write current configuration out to the indicated file.
|
||||
'''
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(self.preamble())
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,config_file_path)
|
||||
|
||||
def preamble(self):
|
||||
'''
|
||||
Returns the preamble for the config file.
|
||||
'''
|
||||
return '''# This file describes the alternative machine learning models
|
||||
# available to the dream script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
'''
|
||||
|
||||
def _invalidate_cached_model(self,model_name:str):
|
||||
self.unload_model(model_name)
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
self.models.pop(model_name,None)
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
model.cond_stage_model.device = 'cpu'
|
||||
|
@ -57,12 +57,13 @@ COMMANDS = (
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||
'!mask',
|
||||
'!models','!switch','!import_model','!edit_model'
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
'!edit_model',
|
||||
'!del_model',
|
||||
)
|
||||
WEIGHT_COMMANDS = (
|
||||
'!import_model',
|
||||
@ -218,9 +219,24 @@ class Completer(object):
|
||||
pydoc.pager('\n'.join(lines))
|
||||
|
||||
def set_line(self,line)->None:
|
||||
'''
|
||||
Set the default string displayed in the next line of input.
|
||||
'''
|
||||
self.linebuffer = line
|
||||
readline.redisplay()
|
||||
|
||||
def add_model(self,model_name:str)->None:
|
||||
'''
|
||||
add a model name to the completion list
|
||||
'''
|
||||
self.models.append(model_name)
|
||||
|
||||
def del_model(self,model_name:str)->None:
|
||||
'''
|
||||
removes a model name from the completion list
|
||||
'''
|
||||
self.models.remove(model_name)
|
||||
|
||||
def _seed_completions(self, text, state):
|
||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||
if m:
|
||||
|
@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!del'):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print('** please provide the name of a model')
|
||||
else:
|
||||
del_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
elif command.startswith('!fetch'):
|
||||
file_path = command.replace('!fetch','',1).strip()
|
||||
retrieve_dream_command(opt,file_path,completer)
|
||||
@ -484,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
new_config['config'] = input('Configuration file for this model: ')
|
||||
done = os.path.exists(new_config['config'])
|
||||
|
||||
done = False
|
||||
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
|
||||
while not done:
|
||||
vae = input('VAE autoencoder file for this model [None]: ')
|
||||
if os.path.exists(vae):
|
||||
new_config['vae'] = vae
|
||||
done = True
|
||||
else:
|
||||
done = len(vae)==0
|
||||
|
||||
completer.complete_extensions(None)
|
||||
|
||||
for field in ('width','height'):
|
||||
@ -498,9 +517,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
except:
|
||||
print('** Please enter a valid integer between 64 and 2048')
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
||||
gen.set_model(model_name)
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||
completer.add_model(model_name)
|
||||
|
||||
def del_config(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
yaml_str = gen.model_cache.del_model(model_name)
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
def edit_config(model_name:str, gen, opt, completer):
|
||||
config = gen.model_cache.config
|
||||
|
||||
@ -512,33 +547,46 @@ def edit_config(model_name:str, gen, opt, completer):
|
||||
|
||||
conf = config[model_name]
|
||||
new_config = {}
|
||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
|
||||
for field in ('description', 'weights', 'config', 'width','height'):
|
||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
||||
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
|
||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||
new_value = input(f'{field}: ')
|
||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
completer.complete_extensions(None)
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
||||
gen.set_model(model_name)
|
||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||
current_model = gen.model_name
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
||||
op = 'modify' if clobber else 'import'
|
||||
print('\n>> New configuration:')
|
||||
if make_default:
|
||||
new_config['default'] = True
|
||||
print(yaml.dump({model_name:new_config}))
|
||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||
return False
|
||||
|
||||
try:
|
||||
print('>> Verifying that new model loads...')
|
||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||
except AssertionError as e:
|
||||
print(f'** configuration failed: {str(e)}')
|
||||
print(f'** aborting **')
|
||||
gen.model_cache.del_model(model_name)
|
||||
return False
|
||||
|
||||
if make_default:
|
||||
print('making this default')
|
||||
gen.model_cache.set_default_model(model_name)
|
||||
|
||||
gen.model_cache.commit(conf_path)
|
||||
|
||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
outfile.write(yaml_str)
|
||||
os.rename(tmpfile,conf_path)
|
||||
do_switch = input(f'Keep model loaded? [y]')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
pass
|
||||
else:
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
def do_textmask(gen, opt, callback):
|
||||
@ -598,7 +646,10 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
|
||||
original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
|
||||
new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
|
||||
meta = retrieve_metadata(original_file)['sd-metadata']
|
||||
img_data = meta['image']
|
||||
if 'image' not in meta:
|
||||
meta = metadata_dumps(opt,seeds=[opt.seed])['image']
|
||||
meta['image'] = {}
|
||||
img_data = meta.get('image')
|
||||
pp = img_data.get('postprocessing',[]) or []
|
||||
pp.append(
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user