Merge branch 'main' of https://github.com/BaristaLabs/stable-diffusion-dream into add-simple-variant-mechanism
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
 # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

 import argparse
+import shlex
+import atexit
@@ -58,7 +60,9 @@ def main():
        weights=weights,
        full_precision=opt.full_precision,
        config=config,
-       latent_diffusion_weights=opt.laion400m  # this is solely for recreating the prompt
+       latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
+       embedding_path=opt.embedding_path,
+       device=opt.device
    )

    # make sure the output directory exists
@@ -108,6 +112,10 @@ def main_loop(t2i,parser,log,infile):
        if command.startswith(('#','//')):
            continue

+       # before splitting, escape single quotes so as not to mess
+       # up the parser
+       command = command.replace("'","\\'")
+
        try:
            elements = shlex.split(command)
        except ValueError as e:
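Why the escaping matters: shlex.split raises ValueError ("No closing quotation") on an unmatched single quote, which apostrophes in prompts trigger constantly. A minimal standalone sketch of the failure and of the fix the hunk above applies:

import shlex

prompt = "portrait of a sailor's parrot"

try:
    shlex.split(prompt)            # the apostrophe reads as an unclosed quote
except ValueError as e:
    print(e)                       # No closing quotation

escaped = prompt.replace("'", "\\'")
print(shlex.split(escaped))        # ['portrait', 'of', 'a', "sailor's", 'parrot']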
@@ -159,10 +167,16 @@ def main_loop(t2i,parser,log,infile):
            print("Try again with a prompt!")
            continue

-       if opt.init_img is None:
-           results = t2i.txt2img(**vars(opt))
-       else:
-           results = t2i.img2img(**vars(opt))
+       try:
+           if opt.init_img is None:
+               results = t2i.txt2img(**vars(opt))
+           else:
+               assert os.path.exists(opt.init_img), f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
+               results = t2i.img2img(**vars(opt))
+       except AssertionError as e:
+           print(e)
+           continue
+
        allVariantResults = []
        if opt.variants is not None:
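The pattern here is worth isolating: a failed precondition on user input is raised as an AssertionError, printed, and swallowed, so the interactive prompt loop survives a bad -I path instead of crashing. A hedged, standalone sketch (generate() is a hypothetical stand-in for t2i.txt2img / t2i.img2img):

import os

def generate(init_img=None):
    if init_img is not None:
        assert os.path.exists(init_img), f"No file found at {init_img}."
    return "image"

for cmd in [{"init_img": "missing.png"}, {"init_img": None}]:
    try:
        result = generate(**cmd)
    except AssertionError as e:
        print(e)          # report and move on, like the dream> prompt loop
        continue
    print("ok:", result)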
@@ -175,13 +189,17 @@ def main_loop(t2i,parser,log,infile):
                for j in range(0, opt.variants):
                    newopt.init_img = resultPath
                    print(f"{newopt.init_img}")
-                   variantResults = t2i.img2img(**vars(newopt))
+                   try:
+                       variantResults = t2i.img2img(**vars(newopt))
+                   except AssertionError as e:
+                       print(e)
+                       continue
                    allVariantResults.append([newopt,variantResults])
            print(f"{opt.variants} Variants generated!")

    print("Outputs:")
    write_log_message(t2i,opt,results,log)

    if len(allVariantResults)>0:
        print("Variant outputs:")
        for vr in allVariantResults:
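This is the "simple variant mechanism" the branch is named for: each freshly generated image is fed back through img2img as an init image, opt.variants times, and failures are now reported rather than aborting the loop. A condensed sketch with hypothetical stand-ins for the setup elided from this hunk (newopt as a copy of opt, results as [path, seed] pairs):

from copy import copy

def generate_variants(t2i, opt, results):
    """Hypothetical distillation of the loop above; results holds [path, seed] pairs."""
    allVariantResults = []
    for resultPath, _seed in results:
        newopt = copy(opt)            # same prompt and settings, new init image
        newopt.init_img = resultPath
        for j in range(0, opt.variants):
            try:
                variantResults = t2i.img2img(**vars(newopt))
            except AssertionError as e:
                print(e)
                continue
            allVariantResults.append([newopt, variantResults])
    return allVariantResults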
@@ -242,6 +260,7 @@ def _reconstruct_switches(t2i,opt):
    switches.append(f'-W{opt.width or t2i.width}')
    switches.append(f'-H{opt.height or t2i.height}')
    switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
+   switches.append(f'-m{t2i.sampler_name}')
    if opt.init_img:
        switches.append(f'-I{opt.init_img}')
    if opt.strength and opt.init_img is not None:
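_reconstruct_switches rebuilds the command-line string written to the prompt log, so with the new -m switch the sampler is recorded alongside the other settings. Illustration only, with invented values:

# A switch list as reconstructed by the code above, now including the sampler.
switches = ['"a forest at dawn"', '-s50', '-W512', '-H512', '-C7.5', '-mk_lms']
print(' '.join(switches))
# "a forest at dawn" -s50 -W512 -H512 -C7.5 -mk_lms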
@@ -282,14 +301,22 @@ def create_argv_parser():
                        help="number of images to produce per iteration (faster, but doesn't generate individual seeds)")
    parser.add_argument('--sampler','-m',
                        dest="sampler_name",
-                       choices=['plms','ddim', 'klms'],
-                       default='klms',
-                       help="which sampler to use (klms) - can only be set on command line")
+                       choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
+                       default='k_lms',
+                       help="which sampler to use (k_lms) - can only be set on command line")
    parser.add_argument('--outdir',
                        '-o',
                        type=str,
                        default="outputs/img-samples",
                        help="directory in which to place generated images and a log of prompts and seeds")
+   parser.add_argument('--embedding_path',
+                       type=str,
+                       help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
+   parser.add_argument('--device',
+                       '-d',
+                       type=str,
+                       default="cuda",
+                       help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available")
    return parser
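A quick, self-contained way to sanity-check how the new and renamed options parse. This is a minimal stand-in mirroring only the arguments touched above, not the real create_argv_parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--sampler', '-m', dest='sampler_name',
                    choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a',
                             'k_euler', 'k_heun', 'k_lms', 'plms'],
                    default='k_lms')
parser.add_argument('--embedding_path', type=str)
parser.add_argument('--device', '-d', type=str, default='cuda')

opt = parser.parse_args(['-m', 'k_euler_a', '-d', 'cpu'])
print(opt.sampler_name, opt.embedding_path, opt.device)  # k_euler_a None cpu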
@@ -397,3 +424,4 @@ if readline_available:

 if __name__ == "__main__":
     main()
+
|
83
scripts/merge_embeddings.py
Normal file
83
scripts/merge_embeddings.py
Normal file
@ -0,0 +1,83 @@
+from ldm.modules.encoders.modules import BERTTokenizer
+from ldm.modules.embedding_manager import EmbeddingManager
+
+import argparse, os
+from functools import partial
+
+import torch
+
+def get_placeholder_loop(placeholder_string, tokenizer):
+
+    new_placeholder = None
+
+    while True:
+        if new_placeholder is None:
+            new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
+        else:
+            new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
+
+        token = tokenizer(new_placeholder)
+
+        if torch.count_nonzero(token) == 3:
+            return new_placeholder, token[0, 1]
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--manager_ckpts",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Paths to a set of embedding managers to be merged."
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        required=True,
+        help="Output path for the merged manager",
+    )
+
+    args = parser.parse_args()
+
+    tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
+    EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
+
+    string_to_token_dict = {}
+    string_to_param_dict = torch.nn.ParameterDict()
+
+    placeholder_to_src = {}
+
+    for manager_ckpt in args.manager_ckpts:
+        print(f"Parsing {manager_ckpt}...")
+
+        manager = EmbeddingManager()
+        manager.load(manager_ckpt)
+
+        for placeholder_string in manager.string_to_token_dict:
+            if not placeholder_string in string_to_token_dict:
+                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
+                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
+
+                placeholder_to_src[placeholder_string] = manager_ckpt
+            else:
+                new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
+                string_to_token_dict[new_placeholder] = new_token
+                string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
+
+                placeholder_to_src[new_placeholder] = manager_ckpt
+
+    print("Saving combined manager...")
+    merged_manager = EmbeddingManager()
+    merged_manager.string_to_param_dict = string_to_param_dict
+    merged_manager.string_to_token_dict = string_to_token_dict
+    merged_manager.save(args.output_path)
+
+    print("Managers merged. Final list of placeholders: ")
+    print(placeholder_to_src)
+
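One detail of the new script worth calling out: get_placeholder_loop only accepts a replacement placeholder whose tokenization has exactly three nonzero entries, i.e. the string maps to a single token between the tokenizer's start and end markers, so token[0, 1] is that one token id. A sketch of the check with a toy tokenizer (this assumes, as the script does, a zero-padded (1, 77) id tensor; the ids here are invented):

import torch

def toy_tokenizer(text):
    ids = [101] + [hash(w) % 1000 + 1 for w in text.split()] + [102]  # start ... end
    out = torch.zeros(1, 77, dtype=torch.long)
    out[0, :len(ids)] = torch.tensor(ids)
    return out

print(torch.count_nonzero(toy_tokenizer("sks")))        # tensor(3): accepted
print(torch.count_nonzero(toy_tokenizer("two words")))  # tensor(4): rejected, loop again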
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
 # Before running stable-diffusion on an internet-isolated machine,
 # run this script from one with internet connectivity. The
 # two machines must share a common .cache directory.
@@ -30,3 +31,4 @@ tokenizer =CLIPTokenizer.from_pretrained(version)
 transformer=CLIPTextModel.from_pretrained(version)
 print('\n\n...success')
+
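For context, these last two hunks touch the cache-preloading script described in its header comment: instantiating the models once on a connected machine populates the shared Hugging Face cache for the isolated one. A hedged reconstruction of the step shown in the hunk context above (the import and the `version` value are assumptions, not shown in this diff):

from transformers import CLIPTokenizer, CLIPTextModel

version = 'openai/clip-vit-large-patch14'  # assumed: the usual SD v1 text encoder
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success')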