mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'development' into patch-1
@@ -6,15 +6,16 @@
 # and the width and height of the images it
 # was trained on.
-laion400m:
-    config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
-    weights: models/ldm/text2img-large/model.ckpt
-    description: Latent Diffusion LAION400M model
-    width: 256
-    height: 256
 stable-diffusion-1.4:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/model.ckpt
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     description: Stable Diffusion inference model version 1.4
     width: 512
     height: 512
+stable-diffusion-1.5:
+    config: configs/stable-diffusion/v1-inference.yaml
+    weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+    description: Stable Diffusion inference model version 1.5
+    width: 512
+    height: 512
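
Note: the stanza schema gains an optional `vae:` key here, and later in this commit a `default:` flag. A minimal sketch of how such a config can be read (assuming OmegaConf is installed and configs/models.yaml exists; the loop is illustrative, not code from this commit):

    from omegaconf import OmegaConf

    config = OmegaConf.load('configs/models.yaml')
    for name in config:
        stanza = config[name]
        # 'vae' and 'default' are both optional keys in the new schema
        marker = '*' if stanza.get('default', False) else ' '
        print(f"{marker} {name}: weights={stanza.weights} vae={stanza.get('vae', 'none')}")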
@@ -8,7 +8,7 @@ hide:
 
 ## **Interactive Command Line Interface**
 
-The `invoke.py` script, located in `scripts/dream.py`, provides an interactive
+The `invoke.py` script, located in `scripts/`, provides an interactive
 interface to image generation similar to the "invoke mothership" bot that Stable
 AI provided on its Discord server.
 
@@ -55,6 +55,9 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
+# this is fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+
 """Simplified text to image API for stable diffusion/latent diffusion
 
 Example Usage:
@@ -129,7 +132,7 @@ class Generate:
 
     def __init__(
             self,
-            model = 'stable-diffusion-1.4',
+            model = None,
             conf = 'configs/models.yaml',
             embedding_path = None,
             sampler_name = 'k_lms',
@@ -145,7 +148,6 @@ class Generate:
             free_gpu_mem=False,
     ):
         mconfig = OmegaConf.load(conf)
-        self.model_name = model
         self.height = None
         self.width = None
         self.model_cache = None
@@ -192,6 +194,7 @@ class Generate:
 
         # model caching system for fast switching
         self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
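
Note: taken together, the Generate changes above replace the hard-coded default with a three-step lookup. A sketch of the equivalent logic, using the names from this diff:

    FALLBACK_MODEL_NAME = 'stable-diffusion-1.4'

    def resolve_model_name(model, model_cache):
        # 1. an explicit `model` argument wins; 2. otherwise the stanza marked
        # `default: true` in models.yaml; 3. otherwise the hard-coded fallback.
        return model or model_cache.default_model() or FALLBACK_MODEL_NAME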
@@ -552,8 +555,11 @@ class Generate:
         from ldm.invoke.restoration.outcrop import Outcrop
         extend_instructions = {}
         for direction,pixels in _pairwise(opt.outcrop):
-            extend_instructions[direction]=int(pixels)
-
-        restorer = Outcrop(image,self,)
-        return restorer.process (
-            extend_instructions,
+            try:
+                extend_instructions[direction]=int(pixels)
+            except ValueError:
+                print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
+        if len(extend_instructions)>0:
+            restorer = Outcrop(image,self,)
+            return restorer.process (
+                extend_instructions,
@@ -697,8 +703,7 @@ class Generate:
 
         model_data = self.model_cache.get_model(model_name)
         if model_data is None or len(model_data) == 0:
-            print(f'** Model switch failed **')
-            return self.model
+            return None
 
         self.model = model_data['model']
         self.width = model_data['width']
@@ -366,17 +366,16 @@ class Args(object):
         deprecated_group.add_argument('--laion400m')
         deprecated_group.add_argument('--weights') # deprecated
         model_group.add_argument(
-            '--conf',
+            '--config',
             '-c',
-            '-conf',
+            '-config',
             dest='conf',
             default='./configs/models.yaml',
             help='Path to configuration file for alternate models.',
         )
         model_group.add_argument(
             '--model',
-            default='stable-diffusion-1.4',
-            help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
+            help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
         )
         model_group.add_argument(
             '--png_compression','-z',
@@ -529,7 +528,7 @@ class Args(object):
             formatter_class=ArgFormatter,
             description=
             """
-            *Image generation:*
+            *Image generation*
                 invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
 
             *postprocessing*
@@ -544,6 +543,13 @@ class Args(object):
             !history lists all the commands issued during the current session.
 
             !NN retrieves the NNth command from the history
+
+            *Model manipulation*
+            !models -- list models in configs/models.yaml
+            !switch <model_name> -- switch to model named <model_name>
+            !import_model path/to/weights/file.ckpt -- adds a model to your config
+            !edit_model <model_name> -- edit a model's description
+            !del_model <model_name> -- delete a model
             """
         )
         render_group = parser.add_argument_group('General rendering')
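
Note: a hypothetical session exercising the new commands (model names and output are illustrative, not captured from the program):

    invoke> !models
    stable-diffusion-1.4   active   Stable Diffusion inference model version 1.4
    stable-diffusion-1.5   cached   Stable Diffusion inference model version 1.5
    invoke> !switch stable-diffusion-1.5
    invoke> !edit_model stable-diffusion-1.5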
@@ -967,17 +973,17 @@ def sha256(path):
     return sha.hexdigest()
 
 def legacy_metadata_load(meta,pathname) -> Args:
+    opt = Args()
     if 'Dream' in meta and len(meta['Dream']) > 0:
         dream_prompt = meta['Dream']
-        opt = Args()
         opt.parse_cmd(dream_prompt)
-        return opt
     else: # if nothing else, we can get the seed
         match = re.search('\d+\.(\d+)',pathname)
         if match:
             seed = match.groups()[0]
-            opt = Args()
             opt.seed = seed
-            return opt
-    return None
+        else:
+            opt.prompt = ''
+            opt.seed = 0
+    return opt
 
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -73,7 +74,8 @@ class ModelCache(object):
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
             print(f'** restoring {self.current_model}')
-            return self.get_model(self.current_model)
+            self.get_model(self.current_model)
+            return None
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -84,6 +86,26 @@ class ModelCache(object):
                 'hash': hash
             }
 
+    def default_model(self) -> str:
+        '''
+        Returns the name of the default model, or None
+        if none is defined.
+        '''
+        for model_name in self.config:
+            if self.config[model_name].get('default',False):
+                return model_name
+        return None
+
+    def set_default_model(self,model_name:str):
+        '''
+        Set the default model. The change will not take
+        effect until you call model_cache.commit()
+        '''
+        assert model_name in self.models,f"unknown model '{model_name}'"
+        for model in self.models:
+            self.models[model].pop('default',None)
+        self.models[model_name]['default'] = True
+
     def list_models(self) -> dict:
         '''
         Return a dict of models in the format:
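
Note: a minimal usage sketch for the two new accessors (assuming `cache` is a constructed ModelCache; `commit()` is added further down in this diff):

    cache.set_default_model('stable-diffusion-1.5')  # flips the 'default' flag in memory
    cache.commit('configs/models.yaml')              # persists the change to disk
    assert cache.default_model() == 'stable-diffusion-1.5'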
@@ -121,12 +143,23 @@ class ModelCache(object):
         else:
             print(line)
 
-    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
+    def del_model(self, model_name:str) ->bool:
+        '''
+        Delete the named model.
+        '''
+        omega = self.config
+        del omega[model_name]
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        return True
+
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
         '''
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
-        On a successful update, the config will be changed in memory and a YAML
-        string will be returned.
+        On a successful update, the config will be changed in memory and the
+        method will return True. Will fail with an assertion error if provided
+        attributes are incorrect or the model name is missing.
         '''
         omega = self.config
         # check that all the required fields are present
@@ -139,7 +172,9 @@ class ModelCache(object):
             config[field] = model_attributes[field]
 
         omega[model_name] = config
-        return OmegaConf.to_yaml(omega)
+        if clobber:
+            self._invalidate_cached_model(model_name)
+        return True
 
     def _check_memory(self):
         avail_memory = psutil.virtual_memory()[1]
@@ -159,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height
 
@@ -188,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')
+
+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
 
         model.eval()
 
         for m in model.modules():
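
Note: the `k[0:4] != "loss"` filter above keeps only inference weights; keys under `loss.*` belong to the VAE's training objective and have no counterpart in `first_stage_model`. The same idea in isolation (checkpoint path hypothetical):

    import torch

    vae_ckpt = torch.load('vae-ft-mse-840000-ema-pruned.ckpt', map_location='cpu')
    # drop training-only tensors such as the perceptual/discriminator loss weights
    vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if not k.startswith('loss')}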
@@ -219,6 +263,36 @@ class ModelCache(object):
         if self._has_cuda():
             torch.cuda.empty_cache()
 
+    def commit(self,config_file_path:str):
+        '''
+        Write current configuration out to the indicated file.
+        '''
+        yaml_str = OmegaConf.to_yaml(self.config)
+        tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
+        with open(tmpfile, 'w') as outfile:
+            outfile.write(self.preamble())
+            outfile.write(yaml_str)
+        os.rename(tmpfile,config_file_path)
+
+    def preamble(self):
+        '''
+        Returns the preamble for the config file.
+        '''
+        return '''# This file describes the alternative machine learning models
+# available to the dream script.
+#
+# To add a new model, follow the examples below. Each
+# model requires a model config file, a weights file,
+# and the width and height of the images it
+# was trained on.
+'''
+
+    def _invalidate_cached_model(self,model_name:str):
+        self.unload_model(model_name)
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        self.models.pop(model_name,None)
+
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
             model.cond_stage_model.device = 'cpu'
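
Note: `commit()` writes to a temporary file in the same directory and then renames it over the target, so an interrupted write cannot leave a truncated models.yaml. The idiom in isolation (on POSIX, `os.rename()` replaces the destination in one atomic step; a sketch, not code from this commit):

    import os

    def atomic_write(path: str, text: str) -> None:
        # staging the temp file beside the destination keeps the rename
        # on a single filesystem, which is what makes it atomic
        tmp = os.path.join(os.path.dirname(path), 'new_config.tmp')
        with open(tmp, 'w') as f:
            f.write(text)
        os.rename(tmp, path)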
@@ -57,12 +57,13 @@ COMMANDS = (
     '--png_compression','-z',
     '--text_mask','-tm',
     '!fix','!fetch','!replay','!history','!search','!clear',
+    '!models','!switch','!import_model','!edit_model','!del_model',
     '!mask',
-    '!models','!switch','!import_model','!edit_model'
     )
 MODEL_COMMANDS = (
     '!switch',
     '!edit_model',
+    '!del_model',
     )
 WEIGHT_COMMANDS = (
     '!import_model',
@@ -218,9 +219,24 @@ class Completer(object):
         pydoc.pager('\n'.join(lines))
 
     def set_line(self,line)->None:
+        '''
+        Set the default string displayed in the next line of input.
+        '''
         self.linebuffer = line
         readline.redisplay()
 
+    def add_model(self,model_name:str)->None:
+        '''
+        add a model name to the completion list
+        '''
+        self.models.append(model_name)
+
+    def del_model(self,model_name:str)->None:
+        '''
+        removes a model name from the completion list
+        '''
+        self.models.remove(model_name)
+
     def _seed_completions(self, text, state):
         m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
         if m:
@@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         completer.add_history(command)
         operation = None
 
+    elif command.startswith('!del'):
+        path = shlex.split(command)
+        if len(path) < 2:
+            print('** please provide the name of a model')
+        else:
+            del_config(path[1], gen, opt, completer)
+        completer.add_history(command)
+        operation = None
+
     elif command.startswith('!fetch'):
         file_path = command.replace('!fetch','',1).strip()
         retrieve_dream_command(opt,file_path,completer)
@@ -484,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])
 
+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)
 
     for field in ('width','height'):
@@ -498,8 +517,24 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         except:
             print('** Please enter a valid integer between 64 and 2048')
 
-    if write_config_file(opt.conf, gen, model_name, new_config):
-        gen.set_model(model_name)
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
+
+    if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
+        completer.add_model(model_name)
+
+def del_config(model_name:str, gen, opt, completer):
+    current_model = gen.model_name
+    if model_name == current_model:
+        print("** Can't delete active model. !switch to another model first. **")
+        return
+    yaml_str = gen.model_cache.del_model(model_name)
+
+    tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
+    with open(tmpfile, 'w') as outfile:
+        outfile.write(yaml_str)
+    os.rename(tmpfile,opt.conf)
+    print(f'** {model_name} deleted')
+    completer.del_model(model_name)
 
 def edit_config(model_name:str, gen, opt, completer):
     config = gen.model_cache.config
@@ -512,33 +547,46 @@ def edit_config(model_name:str, gen, opt, completer):
 
     conf = config[model_name]
     new_config = {}
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
     completer.complete_extensions(None)
+    write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
 
-    if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
-        gen.set_model(model_name)
-
-def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
+def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
+    current_model = gen.model_name
+
     op = 'modify' if clobber else 'import'
     print('\n>> New configuration:')
+    if make_default:
+        new_config['default'] = True
     print(yaml.dump({model_name:new_config}))
     if input(f'OK to {op} [n]? ') not in ('y','Y'):
         return False
 
     try:
+        print('>> Verifying that new model loads...')
         yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
+        assert gen.set_model(model_name) is not None, 'model failed to load'
     except AssertionError as e:
-        print(f'** configuration failed: {str(e)}')
+        print(f'** aborting **')
+        gen.model_cache.del_model(model_name)
         return False
 
-    tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
-    with open(tmpfile, 'w') as outfile:
-        outfile.write(yaml_str)
-    os.rename(tmpfile,conf_path)
+    if make_default:
+        print('making this default')
+        gen.model_cache.set_default_model(model_name)
+
+    gen.model_cache.commit(conf_path)
+
+    do_switch = input(f'Keep model loaded? [y]')
+    if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+        pass
+    else:
+        gen.set_model(current_model)
     return True
 
 def do_textmask(gen, opt, callback):
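
Note: the reworked `write_config_file()` behaves like a small transaction: stage the entry in the in-memory config, verify the model actually loads, roll back with `del_model()` on failure, and only then `commit()` to disk. A condensed sketch of that control flow, using names from this diff:

    def try_import(gen, conf_path, model_name, new_config, clobber=False):
        gen.model_cache.add_model(model_name, new_config, clobber)  # stage
        if gen.set_model(model_name) is None:                       # verify
            gen.model_cache.del_model(model_name)                   # roll back
            return False
        gen.model_cache.commit(conf_path)                           # persist
        return True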
@@ -598,7 +646,10 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
     original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
     new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
     meta = retrieve_metadata(original_file)['sd-metadata']
-    img_data = meta['image']
+    if 'image' not in meta:
+        meta = metadata_dumps(opt,seeds=[opt.seed])['image']
+        meta['image'] = {}
+    img_data = meta.get('image')
     pp = img_data.get('postprocessing',[]) or []
     pp.append(
         {