mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'model-switching' into development
This commit is contained in:
commit
b2bf2b08ff
@ -6,15 +6,15 @@
|
|||||||
# and the width and height of the images it
|
# and the width and height of the images it
|
||||||
# was trained on.
|
# was trained on.
|
||||||
|
|
||||||
laion400m:
|
|
||||||
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
|
|
||||||
weights: models/ldm/text2img-large/model.ckpt
|
|
||||||
description: Latent Diffusion LAION400M model
|
|
||||||
width: 256
|
|
||||||
height: 256
|
|
||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
description: Stable Diffusion inference model version 1.4
|
description: Stable Diffusion inference model version 1.4
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
stable-diffusion-1.5:
|
||||||
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
|
description: Stable Diffusion inference model version 1.5
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like)
|
|||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
torch.bernoulli = fix_func(torch.bernoulli)
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
torch.multinomial = fix_func(torch.multinomial)
|
||||||
|
|
||||||
|
# this is fallback model in case no default is defined
|
||||||
|
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
def new_func(*args, **kw):
|
def new_func(*args, **kw):
|
||||||
@ -147,7 +150,7 @@ class Generate:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model = 'stable-diffusion-1.4',
|
model = None,
|
||||||
conf = 'configs/models.yaml',
|
conf = 'configs/models.yaml',
|
||||||
embedding_path = None,
|
embedding_path = None,
|
||||||
sampler_name = 'k_lms',
|
sampler_name = 'k_lms',
|
||||||
@ -163,7 +166,6 @@ class Generate:
|
|||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
):
|
):
|
||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.model_name = model
|
|
||||||
self.height = None
|
self.height = None
|
||||||
self.width = None
|
self.width = None
|
||||||
self.model_cache = None
|
self.model_cache = None
|
||||||
@ -210,6 +212,8 @@ class Generate:
|
|||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||||
|
print(f'DEBUG: model={model}, default_model={self.model_cache.default_model()}')
|
||||||
|
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
@ -715,8 +719,7 @@ class Generate:
|
|||||||
|
|
||||||
model_data = self.model_cache.get_model(model_name)
|
model_data = self.model_cache.get_model(model_name)
|
||||||
if model_data is None or len(model_data) == 0:
|
if model_data is None or len(model_data) == 0:
|
||||||
print(f'** Model switch failed **')
|
return None
|
||||||
return self.model
|
|
||||||
|
|
||||||
self.model = model_data['model']
|
self.model = model_data['model']
|
||||||
self.width = model_data['width']
|
self.width = model_data['width']
|
||||||
|
@ -366,17 +366,16 @@ class Args(object):
|
|||||||
deprecated_group.add_argument('--laion400m')
|
deprecated_group.add_argument('--laion400m')
|
||||||
deprecated_group.add_argument('--weights') # deprecated
|
deprecated_group.add_argument('--weights') # deprecated
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--conf',
|
'--config',
|
||||||
'-c',
|
'-c',
|
||||||
'-conf',
|
'-config',
|
||||||
dest='conf',
|
dest='conf',
|
||||||
default='./configs/models.yaml',
|
default='./configs/models.yaml',
|
||||||
help='Path to configuration file for alternate models.',
|
help='Path to configuration file for alternate models.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--model',
|
'--model',
|
||||||
default='stable-diffusion-1.4',
|
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
@ -529,7 +528,7 @@ class Args(object):
|
|||||||
formatter_class=ArgFormatter,
|
formatter_class=ArgFormatter,
|
||||||
description=
|
description=
|
||||||
"""
|
"""
|
||||||
*Image generation:*
|
*Image generation*
|
||||||
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||||
|
|
||||||
*postprocessing*
|
*postprocessing*
|
||||||
@ -544,6 +543,13 @@ class Args(object):
|
|||||||
!history lists all the commands issued during the current session.
|
!history lists all the commands issued during the current session.
|
||||||
|
|
||||||
!NN retrieves the NNth command from the history
|
!NN retrieves the NNth command from the history
|
||||||
|
|
||||||
|
*Model manipulation*
|
||||||
|
!models -- list models in configs/models.yaml
|
||||||
|
!switch <model_name> -- switch to model named <model_name>
|
||||||
|
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
||||||
|
!edit_model <model_name> -- edit a model's description
|
||||||
|
!del_model <model_name> -- delete a model
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
render_group = parser.add_argument_group('General rendering')
|
render_group = parser.add_argument_group('General rendering')
|
||||||
|
@ -73,7 +73,8 @@ class ModelCache(object):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||||
print(f'** restoring {self.current_model}')
|
print(f'** restoring {self.current_model}')
|
||||||
return self.get_model(self.current_model)
|
self.get_model(self.current_model)
|
||||||
|
return None
|
||||||
|
|
||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
@ -84,6 +85,26 @@ class ModelCache(object):
|
|||||||
'hash': hash
|
'hash': hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def default_model(self) -> str:
|
||||||
|
'''
|
||||||
|
Returns the name of the default model, or None
|
||||||
|
if none is defined.
|
||||||
|
'''
|
||||||
|
for model_name in self.config:
|
||||||
|
if self.config[model_name].get('default',False):
|
||||||
|
return model_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_default_model(self,model_name:str):
|
||||||
|
'''
|
||||||
|
Set the default model. The change will not take
|
||||||
|
effect until you call model_cache.commit()
|
||||||
|
'''
|
||||||
|
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||||
|
for model in self.models:
|
||||||
|
self.models[model].pop('default',None)
|
||||||
|
self.models[model_name]['default'] = True
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
def list_models(self) -> dict:
|
||||||
'''
|
'''
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
@ -121,12 +142,23 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
def del_model(self, model_name:str) ->bool:
|
||||||
|
'''
|
||||||
|
Delete the named model.
|
||||||
|
'''
|
||||||
|
omega = self.config
|
||||||
|
del omega[model_name]
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||||
'''
|
'''
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory and a YAML
|
On a successful update, the config will be changed in memory and the
|
||||||
string will be returned.
|
method will return True. Will fail with an assertion error if provided
|
||||||
|
attributes are incorrect or the model name is missing.
|
||||||
'''
|
'''
|
||||||
omega = self.config
|
omega = self.config
|
||||||
# check that all the required fields are present
|
# check that all the required fields are present
|
||||||
@ -139,7 +171,9 @@ class ModelCache(object):
|
|||||||
config[field] = model_attributes[field]
|
config[field] = model_attributes[field]
|
||||||
|
|
||||||
omega[model_name] = config
|
omega[model_name] = config
|
||||||
return OmegaConf.to_yaml(omega)
|
if clobber:
|
||||||
|
self._invalidate_cached_model(model_name)
|
||||||
|
return True
|
||||||
|
|
||||||
def _check_memory(self):
|
def _check_memory(self):
|
||||||
avail_memory = psutil.virtual_memory()[1]
|
avail_memory = psutil.virtual_memory()[1]
|
||||||
@ -219,6 +253,36 @@ class ModelCache(object):
|
|||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def commit(self,config_file_path:str):
|
||||||
|
'''
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
'''
|
||||||
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
|
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(self.preamble())
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,config_file_path)
|
||||||
|
|
||||||
|
def preamble(self):
|
||||||
|
'''
|
||||||
|
Returns the preamble for the config file.
|
||||||
|
'''
|
||||||
|
return '''# This file describes the alternative machine learning models
|
||||||
|
# available to the dream script.
|
||||||
|
#
|
||||||
|
# To add a new model, follow the examples below. Each
|
||||||
|
# model requires a model config file, a weights file,
|
||||||
|
# and the width and height of the images it
|
||||||
|
# was trained on.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def _invalidate_cached_model(self,model_name:str):
|
||||||
|
self.unload_model(model_name)
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
self.models.pop(model_name,None)
|
||||||
|
|
||||||
def _model_to_cpu(self,model):
|
def _model_to_cpu(self,model):
|
||||||
if self.device != 'cpu':
|
if self.device != 'cpu':
|
||||||
model.cond_stage_model.device = 'cpu'
|
model.cond_stage_model.device = 'cpu'
|
||||||
|
@ -57,12 +57,13 @@ COMMANDS = (
|
|||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
'--text_mask','-tm',
|
'--text_mask','-tm',
|
||||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||||
|
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||||
'!mask',
|
'!mask',
|
||||||
'!models','!switch','!import_model','!edit_model'
|
|
||||||
)
|
)
|
||||||
MODEL_COMMANDS = (
|
MODEL_COMMANDS = (
|
||||||
'!switch',
|
'!switch',
|
||||||
'!edit_model',
|
'!edit_model',
|
||||||
|
'!del_model',
|
||||||
)
|
)
|
||||||
WEIGHT_COMMANDS = (
|
WEIGHT_COMMANDS = (
|
||||||
'!import_model',
|
'!import_model',
|
||||||
@ -218,9 +219,24 @@ class Completer(object):
|
|||||||
pydoc.pager('\n'.join(lines))
|
pydoc.pager('\n'.join(lines))
|
||||||
|
|
||||||
def set_line(self,line)->None:
|
def set_line(self,line)->None:
|
||||||
|
'''
|
||||||
|
Set the default string displayed in the next line of input.
|
||||||
|
'''
|
||||||
self.linebuffer = line
|
self.linebuffer = line
|
||||||
readline.redisplay()
|
readline.redisplay()
|
||||||
|
|
||||||
|
def add_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
add a model name to the completion list
|
||||||
|
'''
|
||||||
|
self.models.append(model_name)
|
||||||
|
|
||||||
|
def del_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
removes a model name from the completion list
|
||||||
|
'''
|
||||||
|
self.models.remove(model_name)
|
||||||
|
|
||||||
def _seed_completions(self, text, state):
|
def _seed_completions(self, text, state):
|
||||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||||
if m:
|
if m:
|
||||||
|
@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
|||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
|
elif command.startswith('!del'):
|
||||||
|
path = shlex.split(command)
|
||||||
|
if len(path) < 2:
|
||||||
|
print('** please provide the name of a model')
|
||||||
|
else:
|
||||||
|
del_config(path[1], gen, opt, completer)
|
||||||
|
completer.add_history(command)
|
||||||
|
operation = None
|
||||||
|
|
||||||
elif command.startswith('!fetch'):
|
elif command.startswith('!fetch'):
|
||||||
file_path = command.replace('!fetch','',1).strip()
|
file_path = command.replace('!fetch','',1).strip()
|
||||||
retrieve_dream_command(opt,file_path,completer)
|
retrieve_dream_command(opt,file_path,completer)
|
||||||
@ -498,8 +507,24 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
except:
|
except:
|
||||||
print('** Please enter a valid integer between 64 and 2048')
|
print('** Please enter a valid integer between 64 and 2048')
|
||||||
|
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
gen.set_model(model_name)
|
|
||||||
|
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||||
|
completer.add_model(model_name)
|
||||||
|
|
||||||
|
def del_config(model_name:str, gen, opt, completer):
|
||||||
|
current_model = gen.model_name
|
||||||
|
if model_name == current_model:
|
||||||
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
|
return
|
||||||
|
yaml_str = gen.model_cache.del_model(model_name)
|
||||||
|
|
||||||
|
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,opt.conf)
|
||||||
|
print(f'** {model_name} deleted')
|
||||||
|
completer.del_model(model_name)
|
||||||
|
|
||||||
def edit_config(model_name:str, gen, opt, completer):
|
def edit_config(model_name:str, gen, opt, completer):
|
||||||
config = gen.model_cache.config
|
config = gen.model_cache.config
|
||||||
@ -517,28 +542,41 @@ def edit_config(model_name:str, gen, opt, completer):
|
|||||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||||
new_value = input(f'{field}: ')
|
new_value = input(f'{field}: ')
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||||
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||||
|
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||||
gen.set_model(model_name)
|
current_model = gen.model_name
|
||||||
|
|
||||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
|
||||||
op = 'modify' if clobber else 'import'
|
op = 'modify' if clobber else 'import'
|
||||||
print('\n>> New configuration:')
|
print('\n>> New configuration:')
|
||||||
|
if make_default:
|
||||||
|
new_config['default'] = True
|
||||||
print(yaml.dump({model_name:new_config}))
|
print(yaml.dump({model_name:new_config}))
|
||||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print('>> Verifying that new model loads...')
|
||||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||||
|
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(f'** configuration failed: {str(e)}')
|
print(f'** aborting **')
|
||||||
|
gen.model_cache.del_model(model_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
if make_default:
|
||||||
with open(tmpfile, 'w') as outfile:
|
print('making this default')
|
||||||
outfile.write(yaml_str)
|
gen.model_cache.set_default_model(model_name)
|
||||||
os.rename(tmpfile,conf_path)
|
|
||||||
|
gen.model_cache.commit(conf_path)
|
||||||
|
|
||||||
|
do_switch = input(f'Keep model loaded? [y]')
|
||||||
|
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
gen.set_model(current_model)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def do_textmask(gen, opt, callback):
|
def do_textmask(gen, opt, callback):
|
||||||
|
Loading…
Reference in New Issue
Block a user