mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
mostly back to full functionality; just missing grid generation code
This commit is contained in:
parent
b978536385
commit
0b4459b707
@ -4,6 +4,92 @@ import atexit
|
||||
import re
|
||||
from PIL import Image,PngImagePlugin
|
||||
|
||||
# -------------------image generation utils-----
|
||||
class PngWriter:
    """Write generated images to uniquely named PNG files in one directory.

    Files are named ``<counter>.<seed>.png``, or
    ``<counter>.<seed>.<series>.png`` when batch_size > 1 or when the plain
    name is already taken.  The prompt (with ``-S<seed>`` appended) is stored
    in a "Dream" text chunk of each PNG.

    Attributes:
        outdir: directory receiving the files; created by the constructor.
        batch_size: values > 1 force series-numbered file names.
        prompt: prompt text embedded in each PNG.
        filepath: path chosen by the latest write_image() call, else None.
        files_written: list of [filepath, seed], one per write_image() call.
    """

    # A numeric counter, a dot, anything, then ".png" (same pattern the
    # original used, now a raw string so the escapes are valid).
    _NAME_PATTERN = re.compile(r'^(\d+)\..*\.png')

    def __init__(self, outdir, prompt=None, batch_size=1):
        self.outdir = outdir
        self.batch_size = batch_size
        self.prompt = prompt
        self.filepath = None
        self.files_written = []
        os.makedirs(outdir, exist_ok=True)

    def write_image(self, image, seed):
        """Save `image` under a fresh unique name and record [path, seed].

        An IOError raised while saving is printed, not propagated.  The
        [path, seed] entry is appended to files_written either way.
        """
        # will increment name in some sensible way
        self.filepath = self.unique_filename(seed, self.filepath)
        try:
            prompt = f'{self.prompt} -S{seed}'
            self.save_image_and_prompt_to_png(image, prompt, self.filepath)
        except IOError as e:
            print(e)
        self.files_written.append([self.filepath, seed])

    def unique_filename(self, seed, previouspath):
        """Return a path inside outdir for the next image.

        With `previouspath` None, the counter is one more than the largest
        counter found among matching files in outdir (1 if there are none).
        Otherwise the counter of `previouspath` is reused and a series
        suffix is added as needed to avoid an existing file.
        """
        if previouspath is not None:
            match = self._NAME_PATTERN.match(os.path.basename(previouspath))
            if match:
                return self._next_in_series(int(match.group(1)), seed)
            # The previous name is not one of ours.  The original code
            # recursed here with unchanged arguments and never returned;
            # start a new counter instead.
        return self._first_with_new_counter(seed)

    def _first_with_new_counter(self, seed):
        """Pick max(existing counters)+1 and build the first name for it."""
        matches = map(self._NAME_PATTERN.match, os.listdir(self.outdir))
        # Compare counters numerically so names of differing width still
        # yield the true maximum.
        basecount = max((int(m.group(1)) for m in matches if m), default=0) + 1
        if self.batch_size > 1:
            filename = f'{basecount:06}.{seed}.01.png'
        else:
            filename = f'{basecount:06}.{seed}.png'
        return os.path.join(self.outdir, filename)

    def _next_in_series(self, basecount, seed):
        """Return the first unused name for `basecount` and `seed`."""
        series = 0
        while True:
            series += 1
            filename = f'{basecount:06}.{seed}.png'
            plain_taken = os.path.exists(os.path.join(self.outdir, filename))
            if self.batch_size > 1 or plain_taken:
                filename = f'{basecount:06}.{seed}.{series:02}.png'
            candidate = os.path.join(self.outdir, filename)
            if not os.path.exists(candidate):
                return candidate

    def save_image_and_prompt_to_png(self, image, prompt, path):
        """Save `image` to `path` as PNG with `prompt` in a "Dream" text chunk."""
        info = PngImagePlugin.PngInfo()
        info.add_text("Dream", prompt)
        image.save(path, "PNG", pnginfo=info)
|
||||
|
||||
class PromptFormatter():
    """Render the options of one request as a normalized command string."""

    def __init__(self, t2i, opt):
        self.t2i = t2i
        self.opt = opt

    def normalize_prompt(self):
        '''Return the quoted prompt followed by its switches, space-separated.

        Values missing (falsy) on `opt` fall back to the matching attribute
        of `t2i`; the sampler name and full-precision flag come from `t2i`.
        '''
        defaults, given = self.t2i, self.opt

        # Switches that are always present, in their fixed order.
        parts = [
            f'"{given.prompt}"',
            f'-s{given.steps or defaults.steps}',
            f'-b{given.batch_size or defaults.batch_size}',
            f'-W{given.width or defaults.width}',
            f'-H{given.height or defaults.height}',
            f'-C{given.cfg_scale or defaults.cfg_scale}',
            f'-m{defaults.sampler_name}',
        ]

        # Optional switches, emitted only when set.
        if given.variants:
            parts.append(f'-v{given.variants}')
        if given.init_img:
            parts.append(f'-I{given.init_img}')
        if given.strength and given.init_img is not None:
            parts.append(f'-f{given.strength or defaults.strength}')
        if defaults.full_precision:
            parts.append('-F')

        return ' '.join(parts)
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
try:
|
||||
import readline
|
||||
@ -92,88 +178,3 @@ if readline_available:
|
||||
pass
|
||||
atexit.register(readline.write_history_file,histfile)
|
||||
|
||||
# -------------------image generation utils-----
|
||||
class PngWriter:
    """Write generated images to uniquely named PNG files in one directory.

    Files are named ``<counter>.<seed>.png``, or
    ``<counter>.<seed>.<series>.png`` when opt.batch_size > 1 or when the
    plain name is already taken.  The prompt (with ``-S<seed>`` appended) is
    stored in a "Dream" text chunk of each PNG.

    Attributes:
        outdir: directory receiving the files (not created here).
        opt: options object; only its batch_size attribute is read.
        prompt: prompt text embedded in each PNG.
        filepath: path chosen by the latest write_image() call, else None.
        files_written: list of [filepath, seed], one per write_image() call.
    """

    # A numeric counter, a dot, anything, then ".png" (same pattern the
    # original used, now a raw string so the escapes are valid).
    _NAME_PATTERN = re.compile(r'^(\d+)\..*\.png')

    def __init__(self, outdir, opt, prompt):
        self.outdir = outdir
        self.opt = opt
        self.prompt = prompt
        self.filepath = None
        self.files_written = []

    def write_image(self, image, seed):
        """Save `image` under a fresh unique name and record [path, seed].

        An IOError raised while saving is printed, not propagated.  The
        [path, seed] entry is appended to files_written either way.
        """
        # will increment name in some sensible way
        self.filepath = self.unique_filename(seed, self.filepath)
        try:
            prompt = f'{self.prompt} -S{seed}'
            self.save_image_and_prompt_to_png(image, prompt, self.filepath)
        except IOError as e:
            print(e)
        self.files_written.append([self.filepath, seed])

    def unique_filename(self, seed, previouspath):
        """Return a path inside outdir for the next image.

        With `previouspath` None, the counter is one more than the largest
        counter found among matching files in outdir (1 if there are none).
        Otherwise the counter of `previouspath` is reused and a series
        suffix is added as needed to avoid an existing file.
        """
        if previouspath is not None:
            match = self._NAME_PATTERN.match(os.path.basename(previouspath))
            if match:
                return self._next_in_series(int(match.group(1)), seed)
            # The previous name is not one of ours.  The original code
            # recursed here with unchanged arguments and never returned;
            # start a new counter instead.
        return self._first_with_new_counter(seed)

    def _first_with_new_counter(self, seed):
        """Pick max(existing counters)+1 and build the first name for it."""
        matches = map(self._NAME_PATTERN.match, os.listdir(self.outdir))
        # Compare counters numerically so names of differing width still
        # yield the true maximum.
        basecount = max((int(m.group(1)) for m in matches if m), default=0) + 1
        if self.opt.batch_size > 1:
            filename = f'{basecount:06}.{seed}.01.png'
        else:
            filename = f'{basecount:06}.{seed}.png'
        return os.path.join(self.outdir, filename)

    def _next_in_series(self, basecount, seed):
        """Return the first unused name for `basecount` and `seed`."""
        series = 0
        while True:
            series += 1
            filename = f'{basecount:06}.{seed}.png'
            plain_taken = os.path.exists(os.path.join(self.outdir, filename))
            if self.opt.batch_size > 1 or plain_taken:
                filename = f'{basecount:06}.{seed}.{series:02}.png'
            candidate = os.path.join(self.outdir, filename)
            if not os.path.exists(candidate):
                return candidate

    def save_image_and_prompt_to_png(self, image, prompt, path):
        """Save `image` to `path` as PNG with `prompt` in a "Dream" text chunk."""
        info = PngImagePlugin.PngInfo()
        info.add_text("Dream", prompt)
        image.save(path, "PNG", pnginfo=info)
|
||||
|
||||
class PromptFormatter():
    """Render the options of one request as a normalized command string."""

    def __init__(self, t2i, opt):
        self.t2i = t2i
        self.opt = opt

    def normalize_prompt(self):
        '''Return the quoted prompt followed by its switches, space-separated.

        Values missing (falsy) on `opt` fall back to the matching attribute
        of `t2i`; the sampler name and full-precision flag come from `t2i`.
        '''
        defaults, given = self.t2i, self.opt

        # Switches that are always present, in their fixed order.
        parts = [
            f'"{given.prompt}"',
            f'-s{given.steps or defaults.steps}',
            f'-b{given.batch_size or defaults.batch_size}',
            f'-W{given.width or defaults.width}',
            f'-H{given.height or defaults.height}',
            f'-C{given.cfg_scale or defaults.cfg_scale}',
            f'-m{defaults.sampler_name}',
        ]

        # Optional switches, emitted only when set.
        if given.variants:
            parts.append(f'-v{given.variants}')
        if given.init_img:
            parts.append(f'-I{given.init_img}')
        if given.strength and given.init_img is not None:
            parts.append(f'-f{given.strength or defaults.strength}')
        if defaults.full_precision:
            parts.append('-F')

        return ' '.join(parts)
|
||||
|
||||
|
175
ldm/simplet2i.py
175
ldm/simplet2i.py
@ -4,52 +4,6 @@
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
Example Usage:
|
||||
|
||||
from ldm.simplet2i import T2I
|
||||
# Create an object with default values
|
||||
t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
||||
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
|
||||
iterations = <integer> // how many times to run the sampling (1)
|
||||
batch_size = <integer> // how many images to generate per sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
cfg_scale = <float> // unconditional guidance scale (7.5)
|
||||
)
|
||||
|
||||
# do the slow model initialization
|
||||
t2i.load_model()
|
||||
|
||||
# Do the fast inference & image generation. Any options passed here
|
||||
# override the default values assigned during class initialization
|
||||
# Will call load_model() if the model was not previously loaded.
|
||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||
results = t2i.txt2img(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/txt2img-samples"
|
||||
)
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but using an initial image.
|
||||
results = t2i.img2img(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/img2img-samples",
|
||||
init_img = "./sketches/horse+rider.png")
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
@ -64,6 +18,7 @@ from torchvision.utils import make_grid
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import transformers
|
||||
import time
|
||||
import math
|
||||
import re
|
||||
@ -73,6 +28,69 @@ from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.dream_util import PngWriter
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
Example Usage:
|
||||
|
||||
from ldm.simplet2i import T2I
|
||||
|
||||
# Create an object with default values
|
||||
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||
config = <path> // configs/stable-diffusion/v1-inference.yaml
|
||||
iterations = <integer> // how many times to run the sampling (1)
|
||||
batch_size = <integer> // how many images to generate per sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
cfg_scale = <float> // unconditional guidance scale (7.5)
|
||||
)
|
||||
|
||||
# do the slow model initialization
|
||||
t2i.load_model()
|
||||
|
||||
# Do the fast inference & image generation. Any options passed here
|
||||
# override the default values assigned during class initialization
|
||||
# Will call load_model() if the model was not previously loaded and so
|
||||
# may be slow at first.
|
||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/samples",
|
||||
iterations = 3)
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but using an initial image.
|
||||
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/",
|
||||
iterations = 3,
|
||||
init_img = "./sketches/horse+rider.png")
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but we return a series of Image objects, which lets you manipulate them,
|
||||
# combine them, and save them under arbitrary names
|
||||
|
||||
results = t2i.prompt2image(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/")
|
||||
for row in results:
|
||||
im = row[0]
|
||||
seed = row[1]
|
||||
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
|
||||
im.thumbnail((100,100))  # resizes in place and returns None
im.save('./outputs/samples/astronaut_thumb.jpg')
|
||||
|
||||
Note that the old txt2img() and img2img() calls are deprecated but will
|
||||
still work.
|
||||
"""
|
||||
|
||||
|
||||
class T2I:
|
||||
"""T2I class
|
||||
@ -141,7 +159,30 @@ The vast majority of these arguments default to reasonable values.
|
||||
self.seed = self._new_seed()
|
||||
else:
|
||||
self.seed = seed
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
def prompt2png(self,prompt,outdir,**kwargs):
    '''
    Takes a prompt and an output directory, writes out the requested number
    of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
    Optional named arguments are the same as those passed to T2I and prompt2image()
    '''
    results = self.prompt2image(prompt,**kwargs)
    # An explicit batch_size=None is treated like an omitted one; passing
    # None on would make PngWriter's "batch_size > 1" comparison raise.
    batch_size = kwargs.get('batch_size') or self.batch_size
    pngwriter = PngWriter(outdir,prompt,batch_size)
    for r in results:
        # each row is indexed as [image, seed]
        pngwriter.write_image(r[0],r[1])
    return pngwriter.files_written
|
||||
|
||||
def txt2img(self,prompt,**kwargs):
    '''
    Deprecated wrapper around prompt2png().

    Accepts an optional outdir keyword (default 'outputs/img-samples');
    every other keyword is forwarded unchanged.  Returns whatever
    prompt2png() returns.
    '''
    # pop, not get: leaving outdir in kwargs would pass it twice and
    # raise "got multiple values for argument 'outdir'"
    outdir = kwargs.pop('outdir','outputs/img-samples')
    return self.prompt2png(prompt,outdir,**kwargs)
|
||||
|
||||
def img2img(self,prompt,**kwargs):
    '''
    Deprecated wrapper around prompt2png() for image-to-image generation.

    Requires an init_img keyword (AssertionError if it is missing).
    Accepts an optional outdir keyword (default 'outputs/img-samples');
    every other keyword is forwarded unchanged.  Returns whatever
    prompt2png() returns.
    '''
    assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
    # pop, not get: leaving outdir in kwargs would pass it twice and
    # raise "got multiple values for argument 'outdir'"
    outdir = kwargs.pop('outdir','outputs/img-samples')
    return self.prompt2png(prompt,outdir,**kwargs)
|
||||
|
||||
def prompt2image(self,
|
||||
# these are common
|
||||
prompt,
|
||||
@ -161,7 +202,34 @@ The vast majority of these arguments default to reasonable values.
|
||||
strength=None,
|
||||
variants=None,
|
||||
**args): # eat up additional cruft
|
||||
'''ldm.prompt2image() is the common entry point for txt2img() and img2img()'''
|
||||
'''
|
||||
ldm.prompt2image() is the common entry point for txt2img() and img2img()
|
||||
It takes the following arguments:
|
||||
prompt // prompt string (no default)
|
||||
iterations // iterations (1); image count=iterations x batch_size
|
||||
batch_size // images per iteration (1)
|
||||
steps // refinement steps per iteration
|
||||
seed // seed for random number generator
|
||||
width // width of image, in multiples of 64 (512)
|
||||
height // height of image, in multiples of 64 (512)
|
||||
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
|
||||
init_img // path to an initial image - its dimensions override width and height
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||
callback // a function or method that will be called each time an image is generated
|
||||
|
||||
To use the callback, define a function or method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
def process_image(image,seed):
|
||||
image.save(f'images/{seed}.png')
|
||||
|
||||
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
|
||||
to create the requested output directory, select a unique informative name for each image, and
|
||||
write the prompt into the PNG metadata.
|
||||
'''
|
||||
steps = steps or self.steps
|
||||
seed = seed or self.seed
|
||||
width = width or self.width
|
||||
@ -175,6 +243,12 @@ The vast majority of these arguments default to reasonable values.
|
||||
model = self.load_model() # will instantiate the model or return it from cache
|
||||
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
|
||||
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
w = int(width/64) * 64
|
||||
h = int(height/64) * 64
|
||||
if h != height or w != width:
|
||||
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
|
||||
height = h
|
||||
width = w
|
||||
|
||||
data = [batch_size * [prompt]]
|
||||
scope = autocast if self.precision=="autocast" else nullcontext
|
||||
@ -303,8 +377,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
# print(f"target t_enc is {t_enc} steps")
|
||||
images = list()
|
||||
|
||||
try:
|
||||
@ -408,8 +481,8 @@ The vast majority of these arguments default to reasonable values.
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
# if "global_step" in pl_sd:
|
||||
# print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
@ -153,25 +153,58 @@ def main_loop(t2i,outdir,parser,log,infile):
|
||||
continue
|
||||
|
||||
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
|
||||
variants = None
|
||||
|
||||
try:
|
||||
file_writer = PngWriter(outdir,opt,normalized_prompt)
|
||||
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size)
|
||||
callback = file_writer.write_image
|
||||
|
||||
t2i.prompt2image(image_callback=callback,
|
||||
**vars(opt))
|
||||
results = file_writer.files_written
|
||||
|
||||
if None not in (opt.variants,opt.init_img):
|
||||
variants = generate_variants(t2i,outdir,opt,results)
|
||||
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
print("Outputs:")
|
||||
write_log_message(t2i,normalized_prompt,results,log)
|
||||
if variants is not None:
|
||||
print('Variants:')
|
||||
for vr in variants:
|
||||
write_log_message(t2i,vr[0],vr[1],log)
|
||||
|
||||
print("goodbye!")
|
||||
|
||||
def generate_variants(t2i,outdir,opt,previous_gens):
    """Generate opt.variants new images from each previously generated image.

    Each row of `previous_gens` is indexed as [filename, ...]; the filename
    becomes the init image for the variant runs.  Returns a list of
    [normalized_prompt, files_written] pairs, one per successful run.
    An AssertionError raised by a run is printed and that run is skipped.
    """
    print(f"Generating {opt.variants} variant(s)...")

    # Work on a copy so the caller's options are left untouched.
    variant_opt = copy.deepcopy(opt)
    variant_opt.iterations = 1
    variant_opt.variants = None

    collected = []
    for generated in previous_gens:
        variant_opt.init_img = generated[0]
        variant_prompt = PromptFormatter(t2i,variant_opt).normalize_prompt()
        print(f"] generating variant for {variant_opt.init_img}")
        for _ in range(opt.variants):
            try:
                writer = PngWriter(outdir,variant_prompt,variant_opt.batch_size)
                t2i.prompt2image(image_callback=writer.write_image,**vars(variant_opt))
                collected.append([variant_prompt,writer.files_written])
            except AssertionError as err:
                print(err)
    print(f'{opt.variants} variants generated')
    return collected
|
||||
|
||||
|
||||
def write_log_message(t2i,prompt,results,logfile):
|
||||
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
|
||||
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
|
||||
last_seed = None
|
||||
img_num = 1
|
||||
seenit = {}
|
||||
|
Loading…
Reference in New Issue
Block a user