mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactoring complete; please test carefully!
This commit is contained in:
commit
19fa222810
21
README.md
21
README.md
@ -97,6 +97,27 @@ contributing this code.
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
## Reading Prompts from a File
|
||||||
|
|
||||||
|
You can automate dream.py by providing a text file with the prompts
|
||||||
|
you want to run, one line per prompt. The text file must be composed
|
||||||
|
with a text editor (e.g. Notepad) and not a word processor. Each line
|
||||||
|
should look like what you would type at the dream> prompt:
|
||||||
|
|
||||||
|
~~~~
|
||||||
|
a beautiful sunny day in the park, children playing -n4 -C10
|
||||||
|
stormy weather on a mountain top, goats grazing -s100
|
||||||
|
innovative packaging for a squid's dinner -S137038382
|
||||||
|
~~~~
|
||||||
|
|
||||||
|
Then pass this file's name to dream.py when you invoke it:
|
||||||
|
|
||||||
|
~~~~
|
||||||
|
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file="path/to/prompts.txt"
|
||||||
|
~~~~
|
||||||
|
|
||||||
|
>>>>>>> big-refactoring
|
||||||
|
|
||||||
## Weighted Prompts
|
## Weighted Prompts
|
||||||
|
|
||||||
You may weight different sections of the prompt to tell the sampler to attach different levels of
|
You may weight different sections of the prompt to tell the sampler to attach different levels of
|
||||||
|
1
TODO.txt
1
TODO.txt
@ -2,6 +2,7 @@ Feature requests:
|
|||||||
|
|
||||||
1. "gobig" mode - split image into strips, scale up, add detail using
|
1. "gobig" mode - split image into strips, scale up, add detail using
|
||||||
img2img and reassemble with feathering. Issue #66.
|
img2img and reassemble with feathering. Issue #66.
|
||||||
|
See https://github.com/jquesnelle/txt2imghd
|
||||||
|
|
||||||
2. Port basujindal low VRAM optimizations. Issue #62
|
2. Port basujindal low VRAM optimizations. Issue #62
|
||||||
|
|
||||||
|
195
ldm/dream_util.py
Normal file
195
ldm/dream_util.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
'''Utilities for dealing with PNG images and their path names'''
|
||||||
|
import os
|
||||||
|
import atexit
|
||||||
|
import re
|
||||||
|
from math import sqrt,floor,ceil
|
||||||
|
from PIL import Image,PngImagePlugin
|
||||||
|
|
||||||
|
# -------------------image generation utils-----
|
||||||
|
class PngWriter:
|
||||||
|
|
||||||
|
def __init__(self,outdir,prompt=None,batch_size=1):
|
||||||
|
self.outdir = outdir
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.prompt = prompt
|
||||||
|
self.filepath = None
|
||||||
|
self.files_written = []
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
def write_image(self,image,seed):
|
||||||
|
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
|
||||||
|
try:
|
||||||
|
prompt = f'{self.prompt} -S{seed}'
|
||||||
|
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
|
||||||
|
except IOError as e:
|
||||||
|
print(e)
|
||||||
|
self.files_written.append([self.filepath,seed])
|
||||||
|
|
||||||
|
def unique_filename(self,seed,previouspath=None):
|
||||||
|
revision = 1
|
||||||
|
|
||||||
|
if previouspath is None:
|
||||||
|
# sort reverse alphabetically until we find max+1
|
||||||
|
dirlist = sorted(os.listdir(self.outdir),reverse=True)
|
||||||
|
# find the first filename that matches our pattern or return 000000.0.png
|
||||||
|
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
|
||||||
|
basecount = int(filename.split('.',1)[0])
|
||||||
|
basecount += 1
|
||||||
|
if self.batch_size > 1:
|
||||||
|
filename = f'{basecount:06}.{seed}.01.png'
|
||||||
|
else:
|
||||||
|
filename = f'{basecount:06}.{seed}.png'
|
||||||
|
return os.path.join(self.outdir,filename)
|
||||||
|
|
||||||
|
else:
|
||||||
|
basename = os.path.basename(previouspath)
|
||||||
|
x = re.match('^(\d+)\..*\.png',basename)
|
||||||
|
if not x:
|
||||||
|
return self.unique_filename(seed,previouspath)
|
||||||
|
|
||||||
|
basecount = int(x.groups()[0])
|
||||||
|
series = 0
|
||||||
|
finished = False
|
||||||
|
while not finished:
|
||||||
|
series += 1
|
||||||
|
filename = f'{basecount:06}.{seed}.png'
|
||||||
|
if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
|
||||||
|
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||||
|
finished = not os.path.exists(os.path.join(self.outdir,filename))
|
||||||
|
return os.path.join(self.outdir,filename)
|
||||||
|
|
||||||
|
def save_image_and_prompt_to_png(self,image,prompt,path):
|
||||||
|
info = PngImagePlugin.PngInfo()
|
||||||
|
info.add_text("Dream",prompt)
|
||||||
|
image.save(path,"PNG",pnginfo=info)
|
||||||
|
|
||||||
|
def make_grid(self,image_list,rows=None,cols=None):
|
||||||
|
image_cnt = len(image_list)
|
||||||
|
if None in (rows,cols):
|
||||||
|
rows = floor(sqrt(image_cnt)) # try to make it square
|
||||||
|
cols = ceil(image_cnt/rows)
|
||||||
|
width = image_list[0].width
|
||||||
|
height = image_list[0].height
|
||||||
|
|
||||||
|
grid_img = Image.new('RGB',(width*cols,height*rows))
|
||||||
|
for r in range(0,rows):
|
||||||
|
for c in range (0,cols):
|
||||||
|
i = r*rows + c
|
||||||
|
grid_img.paste(image_list[i],(c*width,r*height))
|
||||||
|
|
||||||
|
return grid_img
|
||||||
|
|
||||||
|
class PromptFormatter():
|
||||||
|
def __init__(self,t2i,opt):
|
||||||
|
self.t2i = t2i
|
||||||
|
self.opt = opt
|
||||||
|
|
||||||
|
def normalize_prompt(self):
|
||||||
|
'''Normalize the prompt and switches'''
|
||||||
|
t2i = self.t2i
|
||||||
|
opt = self.opt
|
||||||
|
|
||||||
|
switches = list()
|
||||||
|
switches.append(f'"{opt.prompt}"')
|
||||||
|
switches.append(f'-s{opt.steps or t2i.steps}')
|
||||||
|
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
|
||||||
|
switches.append(f'-W{opt.width or t2i.width}')
|
||||||
|
switches.append(f'-H{opt.height or t2i.height}')
|
||||||
|
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||||
|
switches.append(f'-m{t2i.sampler_name}')
|
||||||
|
if opt.init_img:
|
||||||
|
switches.append(f'-I{opt.init_img}')
|
||||||
|
if opt.strength and opt.init_img is not None:
|
||||||
|
switches.append(f'-f{opt.strength or t2i.strength}')
|
||||||
|
if t2i.full_precision:
|
||||||
|
switches.append('-F')
|
||||||
|
return ' '.join(switches)
|
||||||
|
|
||||||
|
# ---------------readline utilities---------------------
|
||||||
|
try:
|
||||||
|
import readline
|
||||||
|
readline_available = True
|
||||||
|
except:
|
||||||
|
readline_available = False
|
||||||
|
|
||||||
|
class Completer():
|
||||||
|
def __init__(self,options):
|
||||||
|
self.options = sorted(options)
|
||||||
|
return
|
||||||
|
|
||||||
|
def complete(self,text,state):
|
||||||
|
buffer = readline.get_line_buffer()
|
||||||
|
|
||||||
|
if text.startswith(('-I','--init_img')):
|
||||||
|
return self._path_completions(text,state,('.png'))
|
||||||
|
|
||||||
|
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
|
||||||
|
return self._path_completions(text,state,())
|
||||||
|
|
||||||
|
response = None
|
||||||
|
if state == 0:
|
||||||
|
# This is the first time for this text, so build a match list.
|
||||||
|
if text:
|
||||||
|
self.matches = [s
|
||||||
|
for s in self.options
|
||||||
|
if s and s.startswith(text)]
|
||||||
|
else:
|
||||||
|
self.matches = self.options[:]
|
||||||
|
|
||||||
|
# Return the state'th item from the match list,
|
||||||
|
# if we have that many.
|
||||||
|
try:
|
||||||
|
response = self.matches[state]
|
||||||
|
except IndexError:
|
||||||
|
response = None
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _path_completions(self,text,state,extensions):
|
||||||
|
# get the path so far
|
||||||
|
if text.startswith('-I'):
|
||||||
|
path = text.replace('-I','',1).lstrip()
|
||||||
|
elif text.startswith('--init_img='):
|
||||||
|
path = text.replace('--init_img=','',1).lstrip()
|
||||||
|
else:
|
||||||
|
path = text
|
||||||
|
|
||||||
|
matches = list()
|
||||||
|
|
||||||
|
path = os.path.expanduser(path)
|
||||||
|
if len(path)==0:
|
||||||
|
matches.append(text+'./')
|
||||||
|
else:
|
||||||
|
dir = os.path.dirname(path)
|
||||||
|
dir_list = os.listdir(dir)
|
||||||
|
for n in dir_list:
|
||||||
|
if n.startswith('.') and len(n)>1:
|
||||||
|
continue
|
||||||
|
full_path = os.path.join(dir,n)
|
||||||
|
if full_path.startswith(path):
|
||||||
|
if os.path.isdir(full_path):
|
||||||
|
matches.append(os.path.join(os.path.dirname(text),n)+'/')
|
||||||
|
elif n.endswith(extensions):
|
||||||
|
matches.append(os.path.join(os.path.dirname(text),n))
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = matches[state]
|
||||||
|
except IndexError:
|
||||||
|
response = None
|
||||||
|
return response
|
||||||
|
|
||||||
|
if readline_available:
|
||||||
|
readline.set_completer(Completer(['cd','pwd',
|
||||||
|
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
|
||||||
|
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
|
||||||
|
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
|
||||||
|
readline.set_completer_delims(" ")
|
||||||
|
readline.parse_and_bind('tab: complete')
|
||||||
|
|
||||||
|
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
|
||||||
|
try:
|
||||||
|
readline.read_history_file(histfile)
|
||||||
|
readline.set_history_length(1000)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
atexit.register(readline.write_history_file,histfile)
|
||||||
|
|
490
ldm/simplet2i.py
490
ldm/simplet2i.py
@ -4,53 +4,6 @@
|
|||||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||||
|
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
|
||||||
|
|
||||||
Example Usage:
|
|
||||||
|
|
||||||
from ldm.simplet2i import T2I
|
|
||||||
# Create an object with default values
|
|
||||||
t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
|
||||||
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
|
||||||
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
|
|
||||||
iterations = <integer> // how many times to run the sampling (1)
|
|
||||||
batch_size = <integer> // how many images to generate per sampling (1)
|
|
||||||
steps = <integer> // 50
|
|
||||||
seed = <integer> // current system time
|
|
||||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
|
||||||
grid = <boolean> // false
|
|
||||||
width = <integer> // image width, multiple of 64 (512)
|
|
||||||
height = <integer> // image height, multiple of 64 (512)
|
|
||||||
cfg_scale = <float> // unconditional guidance scale (7.5)
|
|
||||||
fixed_code = <boolean> // False
|
|
||||||
)
|
|
||||||
|
|
||||||
# do the slow model initialization
|
|
||||||
t2i.load_model()
|
|
||||||
|
|
||||||
# Do the fast inference & image generation. Any options passed here
|
|
||||||
# override the default values assigned during class initialization
|
|
||||||
# Will call load_model() if the model was not previously loaded.
|
|
||||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
|
||||||
results = t2i.txt2img(prompt = "an astronaut riding a horse"
|
|
||||||
outdir = "./outputs/txt2img-samples)
|
|
||||||
)
|
|
||||||
|
|
||||||
for row in results:
|
|
||||||
print(f'filename={row[0]}')
|
|
||||||
print(f'seed ={row[1]}')
|
|
||||||
|
|
||||||
# Same thing, but using an initial image.
|
|
||||||
results = t2i.img2img(prompt = "an astronaut riding a horse"
|
|
||||||
outdir = "./outputs/img2img-samples"
|
|
||||||
init_img = "./sketches/horse+rider.png")
|
|
||||||
|
|
||||||
for row in results:
|
|
||||||
print(f'filename={row[0]}')
|
|
||||||
print(f'seed ={row[1]}')
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
@ -65,8 +18,8 @@ from torchvision.utils import make_grid
|
|||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
|
import transformers
|
||||||
import time
|
import time
|
||||||
import math
|
|
||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -74,12 +27,74 @@ from ldm.util import instantiate_from_config
|
|||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
|
from ldm.dream_util import PngWriter
|
||||||
|
|
||||||
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
|
||||||
|
from ldm.simplet2i import T2I
|
||||||
|
|
||||||
|
# Create an object with default values
|
||||||
|
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
config = <path> // configs/stable-diffusion/v1-inference.yaml
|
||||||
|
iterations = <integer> // how many times to run the sampling (1)
|
||||||
|
batch_size = <integer> // how many images to generate per sampling (1)
|
||||||
|
steps = <integer> // 50
|
||||||
|
seed = <integer> // current system time
|
||||||
|
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||||
|
grid = <boolean> // false
|
||||||
|
width = <integer> // image width, multiple of 64 (512)
|
||||||
|
height = <integer> // image height, multiple of 64 (512)
|
||||||
|
cfg_scale = <float> // unconditional guidance scale (7.5)
|
||||||
|
)
|
||||||
|
|
||||||
|
# do the slow model initialization
|
||||||
|
t2i.load_model()
|
||||||
|
|
||||||
|
# Do the fast inference & image generation. Any options passed here
|
||||||
|
# override the default values assigned during class initialization
|
||||||
|
# Will call load_model() if the model was not previously loaded and so
|
||||||
|
# may be slow at first.
|
||||||
|
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||||
|
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||||
|
outdir = "./outputs/samples",
|
||||||
|
iterations = 3)
|
||||||
|
|
||||||
|
for row in results:
|
||||||
|
print(f'filename={row[0]}')
|
||||||
|
print(f'seed ={row[1]}')
|
||||||
|
|
||||||
|
# Same thing, but using an initial image.
|
||||||
|
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||||
|
outdir = "./outputs/,
|
||||||
|
iterations = 3,
|
||||||
|
init_img = "./sketches/horse+rider.png")
|
||||||
|
|
||||||
|
for row in results:
|
||||||
|
print(f'filename={row[0]}')
|
||||||
|
print(f'seed ={row[1]}')
|
||||||
|
|
||||||
|
# Same thing, but we return a series of Image objects, which lets you manipulate them,
|
||||||
|
# combine them, and save them under arbitrary names
|
||||||
|
|
||||||
|
results = t2i.prompt2image(prompt = "an astronaut riding a horse"
|
||||||
|
outdir = "./outputs/")
|
||||||
|
for row in results:
|
||||||
|
im = row[0]
|
||||||
|
seed = row[1]
|
||||||
|
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
|
||||||
|
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
|
||||||
|
|
||||||
|
Note that the old txt2img() and img2img() calls are deprecated but will
|
||||||
|
still work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class T2I:
|
class T2I:
|
||||||
"""T2I class
|
"""T2I class
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
outdir
|
|
||||||
model
|
model
|
||||||
config
|
config
|
||||||
iterations
|
iterations
|
||||||
@ -87,12 +102,9 @@ class T2I:
|
|||||||
steps
|
steps
|
||||||
seed
|
seed
|
||||||
sampler_name
|
sampler_name
|
||||||
grid
|
|
||||||
individual
|
|
||||||
width
|
width
|
||||||
height
|
height
|
||||||
cfg_scale
|
cfg_scale
|
||||||
fixed_code
|
|
||||||
latent_channels
|
latent_channels
|
||||||
downsampling_factor
|
downsampling_factor
|
||||||
precision
|
precision
|
||||||
@ -102,23 +114,19 @@ class T2I:
|
|||||||
The vast majority of these arguments default to reasonable values.
|
The vast majority of these arguments default to reasonable values.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
outdir="outputs/txt2img-samples",
|
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
iterations = 1,
|
iterations = 1,
|
||||||
width=512,
|
|
||||||
height=512,
|
|
||||||
grid=False,
|
|
||||||
individual=None, # redundant
|
|
||||||
steps=50,
|
steps=50,
|
||||||
seed=None,
|
seed=None,
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
weights="models/ldm/stable-diffusion-v1/model.ckpt",
|
weights="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||||
config = "configs/stable-diffusion/v1-inference.yaml",
|
config = "configs/stable-diffusion/v1-inference.yaml",
|
||||||
|
width=512,
|
||||||
|
height=512,
|
||||||
sampler_name="klms",
|
sampler_name="klms",
|
||||||
latent_channels=4,
|
latent_channels=4,
|
||||||
downsampling_factor=8,
|
downsampling_factor=8,
|
||||||
ddim_eta=0.0, # deterministic
|
ddim_eta=0.0, # deterministic
|
||||||
fixed_code=False,
|
|
||||||
precision='autocast',
|
precision='autocast',
|
||||||
full_precision=False,
|
full_precision=False,
|
||||||
strength=0.75, # default in scripts/img2img.py
|
strength=0.75, # default in scripts/img2img.py
|
||||||
@ -126,18 +134,15 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
|
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
|
||||||
device='cuda'
|
device='cuda'
|
||||||
):
|
):
|
||||||
self.outdir = outdir
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.iterations = iterations
|
self.iterations = iterations
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.grid = grid
|
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.cfg_scale = cfg_scale
|
self.cfg_scale = cfg_scale
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.config = config
|
self.config = config
|
||||||
self.sampler_name = sampler_name
|
self.sampler_name = sampler_name
|
||||||
self.fixed_code = fixed_code
|
|
||||||
self.latent_channels = latent_channels
|
self.latent_channels = latent_channels
|
||||||
self.downsampling_factor = downsampling_factor
|
self.downsampling_factor = downsampling_factor
|
||||||
self.ddim_eta = ddim_eta
|
self.ddim_eta = ddim_eta
|
||||||
@ -153,17 +158,77 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
self.seed = self._new_seed()
|
self.seed = self._new_seed()
|
||||||
else:
|
else:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@torch.no_grad()
|
def prompt2png(self,prompt,outdir,**kwargs):
|
||||||
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
|
'''
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
Takes a prompt and an output directory, writes out the requested number
|
||||||
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,
|
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
|
||||||
skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed
|
Optional named arguments are the same as those passed to T2I and prompt2image()
|
||||||
"""
|
'''
|
||||||
Generate an image from the prompt, writing iteration images into the outdir
|
results = self.prompt2image(prompt,**kwargs)
|
||||||
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size))
|
||||||
"""
|
for r in results:
|
||||||
outdir = outdir or self.outdir
|
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
|
||||||
|
pngwriter.write_image(r[0],r[1])
|
||||||
|
return pngwriter.files_written
|
||||||
|
|
||||||
|
def txt2img(self,prompt,**kwargs):
|
||||||
|
outdir = kwargs.get('outdir','outputs/img-samples')
|
||||||
|
return self.prompt2png(prompt,outdir,**kwargs)
|
||||||
|
|
||||||
|
def img2img(self,prompt,**kwargs):
|
||||||
|
outdir = kwargs.get('outdir','outputs/img-samples')
|
||||||
|
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
|
||||||
|
return self.prompt2png(prompt,outdir,**kwargs)
|
||||||
|
|
||||||
|
def prompt2image(self,
|
||||||
|
# these are common
|
||||||
|
prompt,
|
||||||
|
batch_size=None,
|
||||||
|
iterations=None,
|
||||||
|
steps=None,
|
||||||
|
seed=None,
|
||||||
|
cfg_scale=None,
|
||||||
|
ddim_eta=None,
|
||||||
|
skip_normalize=False,
|
||||||
|
image_callback=None,
|
||||||
|
# these are specific to txt2img
|
||||||
|
width=None,
|
||||||
|
height=None,
|
||||||
|
# these are specific to img2img
|
||||||
|
init_img=None,
|
||||||
|
strength=None,
|
||||||
|
variants=None,
|
||||||
|
**args): # eat up additional cruft
|
||||||
|
'''
|
||||||
|
ldm.prompt2image() is the common entry point for txt2img() and img2img()
|
||||||
|
It takes the following arguments:
|
||||||
|
prompt // prompt string (no default)
|
||||||
|
iterations // iterations (1); image count=iterations x batch_size
|
||||||
|
batch_size // images per iteration (1)
|
||||||
|
steps // refinement steps per iteration
|
||||||
|
seed // seed for random number generator
|
||||||
|
width // width of image, in multiples of 64 (512)
|
||||||
|
height // height of image, in multiples of 64 (512)
|
||||||
|
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
|
||||||
|
init_img // path to an initial image - its dimensions override width and height
|
||||||
|
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||||
|
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||||
|
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||||
|
callback // a function or method that will be called each time an image is generated
|
||||||
|
|
||||||
|
To use the callback, define a function of method that receives two arguments, an Image object
|
||||||
|
and the seed. You can then do whatever you like with the image, including converting it to
|
||||||
|
different formats and manipulating it. For example:
|
||||||
|
|
||||||
|
def process_image(image,seed):
|
||||||
|
image.save(f{'images/seed.png'})
|
||||||
|
|
||||||
|
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
|
||||||
|
to create the requested output directory, select a unique informative name for each image, and
|
||||||
|
write the prompt into the PNG metadata.
|
||||||
|
'''
|
||||||
steps = steps or self.steps
|
steps = steps or self.steps
|
||||||
seed = seed or self.seed
|
seed = seed or self.seed
|
||||||
width = width or self.width
|
width = width or self.width
|
||||||
@ -172,52 +237,70 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
ddim_eta = ddim_eta or self.ddim_eta
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
batch_size = batch_size or self.batch_size
|
batch_size = batch_size or self.batch_size
|
||||||
iterations = iterations or self.iterations
|
iterations = iterations or self.iterations
|
||||||
strength = strength or self.strength # not actually used here, but preserved for code refactoring
|
strength = strength or self.strength
|
||||||
embedding_path = embedding_path or self.embedding_path
|
|
||||||
|
|
||||||
model = self.load_model() # will instantiate the model or return it from cache
|
model = self.load_model() # will instantiate the model or return it from cache
|
||||||
|
|
||||||
assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
|
|
||||||
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
|
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
|
||||||
|
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
w = int(width/64) * 64
|
||||||
|
h = int(height/64) * 64
|
||||||
|
if h != height or w != width:
|
||||||
|
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
|
||||||
|
height = h
|
||||||
|
width = w
|
||||||
|
|
||||||
# grid and individual are mutually exclusive, with individual taking priority.
|
|
||||||
# not necessary, but needed for compatability with dream bot
|
|
||||||
if (grid is None):
|
|
||||||
grid = self.grid
|
|
||||||
if individual:
|
|
||||||
grid = False
|
|
||||||
|
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
scope = autocast if self.precision=="autocast" else nullcontext
|
||||||
|
|
||||||
# make directories and establish names for the output files
|
tic = time.time()
|
||||||
os.makedirs(outdir, exist_ok=True)
|
if init_img:
|
||||||
|
assert os.path.exists(init_img),f'{init_img}: File not found'
|
||||||
|
results = self._img2img(prompt,
|
||||||
|
data=data,precision_scope=scope,
|
||||||
|
batch_size=batch_size,iterations=iterations,
|
||||||
|
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
|
||||||
|
skip_normalize=skip_normalize,
|
||||||
|
init_img=init_img,strength=strength,variants=variants,
|
||||||
|
callback=image_callback)
|
||||||
|
else:
|
||||||
|
results = self._txt2img(prompt,
|
||||||
|
data=data,precision_scope=scope,
|
||||||
|
batch_size=batch_size,iterations=iterations,
|
||||||
|
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
|
||||||
|
skip_normalize=skip_normalize,
|
||||||
|
width=width,height=height,
|
||||||
|
callback=image_callback)
|
||||||
|
toc = time.time()
|
||||||
|
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
|
||||||
|
return results
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _txt2img(self,prompt,
|
||||||
|
data,precision_scope,
|
||||||
|
batch_size,iterations,
|
||||||
|
steps,seed,cfg_scale,ddim_eta,
|
||||||
|
skip_normalize,
|
||||||
|
width,height,
|
||||||
|
callback): # the callback is called each time a new Image is generated
|
||||||
|
"""
|
||||||
|
Generate an image from the prompt, writing iteration images into the outdir
|
||||||
|
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
|
||||||
|
"""
|
||||||
|
|
||||||
start_code = None
|
|
||||||
if self.fixed_code:
|
|
||||||
start_code = torch.randn([batch_size,
|
|
||||||
self.latent_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor],
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
precision_scope = autocast if self.precision=="autocast" else nullcontext
|
|
||||||
sampler = self.sampler
|
sampler = self.sampler
|
||||||
images = list()
|
images = list()
|
||||||
seeds = list()
|
|
||||||
filename = None
|
|
||||||
image_count = 0
|
image_count = 0
|
||||||
tic = time.time()
|
|
||||||
|
|
||||||
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
|
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
|
||||||
try:
|
try:
|
||||||
with precision_scope(self.device.type), model.ema_scope():
|
with precision_scope(self.device.type), self.model.ema_scope():
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(iterations, desc="Sampling"):
|
for n in trange(iterations, desc="Sampling"):
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
uc = None
|
uc = None
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch_size * [""])
|
uc = self.model.get_learned_conditioning(batch_size * [""])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
|
|
||||||
@ -233,138 +316,78 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
weight = weights[i]
|
weight = weights[i]
|
||||||
if not skip_normalize:
|
if not skip_normalize:
|
||||||
weight = weight / totalWeight
|
weight = weight / totalWeight
|
||||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||||
else: # just standard 1 prompt
|
else: # just standard 1 prompt
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = self.model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
||||||
samples_ddim, _ = sampler.sample(S=steps,
|
samples_ddim, _ = sampler.sample(S=steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta)
|
||||||
x_T=start_code)
|
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
for x_sample in x_samples_ddim:
|
||||||
if not grid:
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
for x_sample in x_samples_ddim:
|
image = Image.fromarray(x_sample.astype(np.uint8))
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
images.append([image,seed])
|
||||||
filename = self._unique_filename(outdir,previousname=filename,
|
if callback is not None:
|
||||||
seed=seed,isbatch=(batch_size>1))
|
callback(image,seed)
|
||||||
assert not os.path.exists(filename)
|
|
||||||
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
|
|
||||||
images.append([filename,seed])
|
|
||||||
else:
|
|
||||||
all_samples.append(x_samples_ddim)
|
|
||||||
seeds.append(seed)
|
|
||||||
|
|
||||||
image_count += 1
|
|
||||||
seed = self._new_seed()
|
seed = self._new_seed()
|
||||||
if grid:
|
|
||||||
images = self._make_grid(samples=all_samples,
|
|
||||||
seeds=seeds,
|
|
||||||
batch_size=batch_size,
|
|
||||||
iterations=iterations,
|
|
||||||
outdir=outdir)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('*interrupted*')
|
print('*interrupted*')
|
||||||
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
|
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
toc = time.time()
|
|
||||||
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
|
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
# There is lots of shared code between this and txt2img and should be refactored.
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
|
def _img2img(self,prompt,
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
data,precision_scope,
|
||||||
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
|
batch_size,iterations,
|
||||||
skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed
|
steps,seed,cfg_scale,ddim_eta,
|
||||||
|
skip_normalize,
|
||||||
|
init_img,strength,variants,
|
||||||
|
callback):
|
||||||
"""
|
"""
|
||||||
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
||||||
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
|
||||||
"""
|
"""
|
||||||
outdir = outdir or self.outdir
|
|
||||||
steps = steps or self.steps
|
|
||||||
seed = seed or self.seed
|
|
||||||
cfg_scale = cfg_scale or self.cfg_scale
|
|
||||||
ddim_eta = ddim_eta or self.ddim_eta
|
|
||||||
batch_size = batch_size or self.batch_size
|
|
||||||
iterations = iterations or self.iterations
|
|
||||||
strength = strength or self.strength
|
|
||||||
embedding_path = embedding_path or self.embedding_path
|
|
||||||
|
|
||||||
assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
|
|
||||||
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
|
|
||||||
|
|
||||||
if init_img is None:
|
|
||||||
print("no init_img provided!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
model = self.load_model() # will instantiate the model or return it from cache
|
|
||||||
|
|
||||||
precision_scope = autocast if self.precision=="autocast" else nullcontext
|
|
||||||
|
|
||||||
# grid and individual are mutually exclusive, with individual taking priority.
|
|
||||||
# not necessary, but needed for compatability with dream bot
|
|
||||||
if (grid is None):
|
|
||||||
grid = self.grid
|
|
||||||
if individual:
|
|
||||||
grid = False
|
|
||||||
|
|
||||||
data = [batch_size * [prompt]]
|
|
||||||
|
|
||||||
# PLMS sampler not supported yet, so ignore previous sampler
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
if self.sampler_name!='ddim':
|
if self.sampler_name!='ddim':
|
||||||
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
||||||
sampler = DDIMSampler(model, device=self.device)
|
sampler = DDIMSampler(self.model, device=self.device)
|
||||||
else:
|
else:
|
||||||
sampler = self.sampler
|
sampler = self.sampler
|
||||||
|
|
||||||
# make directories and establish names for the output files
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
|
||||||
|
|
||||||
assert os.path.isfile(init_img)
|
|
||||||
init_image = self._load_img(init_img).to(self.device)
|
init_image = self._load_img(init_img).to(self.device)
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
with precision_scope(self.device.type):
|
with precision_scope(self.device.type):
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||||
|
|
||||||
try:
|
|
||||||
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
|
||||||
except AssertionError:
|
|
||||||
print(f"strength must be between 0.0 and 1.0, but received value {strength}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
t_enc = int(strength * steps)
|
t_enc = int(strength * steps)
|
||||||
print(f"target t_enc is {t_enc} steps")
|
# print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
images = list()
|
images = list()
|
||||||
seeds = list()
|
|
||||||
filename = None
|
|
||||||
image_count = 0 # actual number of iterations performed
|
|
||||||
tic = time.time()
|
|
||||||
|
|
||||||
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
|
|
||||||
try:
|
try:
|
||||||
with precision_scope(self.device.type), model.ema_scope():
|
with precision_scope(self.device.type), self.model.ema_scope():
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(iterations, desc="Sampling"):
|
for n in trange(iterations, desc="Sampling"):
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
uc = None
|
uc = None
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch_size * [""])
|
uc = self.model.get_learned_conditioning(batch_size * [""])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
|
|
||||||
@ -380,9 +403,9 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
weight = weights[i]
|
weight = weights[i]
|
||||||
if not skip_normalize:
|
if not skip_normalize:
|
||||||
weight = weight / totalWeight
|
weight = weight / totalWeight
|
||||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||||
else: # just standard 1 prompt
|
else: # just standard 1 prompt
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = self.model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
# encode (scaled latent)
|
# encode (scaled latent)
|
||||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
|
||||||
@ -390,28 +413,16 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,)
|
unconditional_conditioning=uc,)
|
||||||
|
|
||||||
x_samples = model.decode_first_stage(samples)
|
x_samples = self.model.decode_first_stage(samples)
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if not grid:
|
for x_sample in x_samples:
|
||||||
for x_sample in x_samples:
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
image = Image.fromarray(x_sample.astype(np.uint8))
|
||||||
filename = self._unique_filename(outdir,previousname=filename,
|
images.append([image,seed])
|
||||||
seed=seed,isbatch=(batch_size>1))
|
if callback is not None:
|
||||||
assert not os.path.exists(filename)
|
callback(image,seed)
|
||||||
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
|
|
||||||
images.append([filename,seed])
|
|
||||||
else:
|
|
||||||
all_samples.append(x_samples)
|
|
||||||
seeds.append(seed)
|
|
||||||
image_count +=1
|
|
||||||
seed = self._new_seed()
|
seed = self._new_seed()
|
||||||
if grid:
|
|
||||||
images = self._make_grid(samples=all_samples,
|
|
||||||
seeds=seeds,
|
|
||||||
batch_size=batch_size,
|
|
||||||
iterations=iterations,
|
|
||||||
outdir=outdir)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('*interrupted*')
|
print('*interrupted*')
|
||||||
@ -419,26 +430,6 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
|
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
toc = time.time()
|
|
||||||
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
|
|
||||||
images = list()
|
|
||||||
n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
|
|
||||||
# save as grid
|
|
||||||
grid = torch.stack(samples, 0)
|
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
|
||||||
grid = make_grid(grid, nrow=n_rows)
|
|
||||||
|
|
||||||
# to image
|
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
|
||||||
filename = self._unique_filename(outdir,seed=seeds[0],grid_count=batch_size*iterations)
|
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(filename)
|
|
||||||
for s in seeds:
|
|
||||||
images.append([filename,s])
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def _new_seed(self):
|
def _new_seed(self):
|
||||||
@ -489,8 +480,8 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
def _load_model_from_config(self, config, ckpt):
|
def _load_model_from_config(self, config, ckpt):
|
||||||
print(f"Loading model from {ckpt}")
|
print(f"Loading model from {ckpt}")
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
if "global_step" in pl_sd:
|
# if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
# print(f"Global Step: {pl_sd['global_step']}")
|
||||||
sd = pl_sd["state_dict"]
|
sd = pl_sd["state_dict"]
|
||||||
model = instantiate_from_config(config.model)
|
model = instantiate_from_config(config.model)
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
@ -514,43 +505,6 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
return 2.*image - 1.
|
return 2.*image - 1.
|
||||||
|
|
||||||
def _unique_filename(self,outdir,previousname=None,seed=0,isbatch=False,grid_count=None):
|
|
||||||
revision = 1
|
|
||||||
|
|
||||||
if previousname is None:
|
|
||||||
# sort reverse alphabetically until we find max+1
|
|
||||||
dirlist = sorted(os.listdir(outdir),reverse=True)
|
|
||||||
# find the first filename that matches our pattern or return 000000.0.png
|
|
||||||
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
|
|
||||||
basecount = int(filename.split('.',1)[0])
|
|
||||||
basecount += 1
|
|
||||||
if grid_count is not None:
|
|
||||||
grid_label = f'grid#1-{grid_count}'
|
|
||||||
filename = f'{basecount:06}.{seed}.{grid_label}.png'
|
|
||||||
elif isbatch:
|
|
||||||
filename = f'{basecount:06}.{seed}.01.png'
|
|
||||||
else:
|
|
||||||
filename = f'{basecount:06}.{seed}.png'
|
|
||||||
|
|
||||||
return os.path.join(outdir,filename)
|
|
||||||
|
|
||||||
else:
|
|
||||||
previousname = os.path.basename(previousname)
|
|
||||||
x = re.match('^(\d+)\..*\.png',previousname)
|
|
||||||
if not x:
|
|
||||||
return self._unique_filename(outdir,previousname,seed)
|
|
||||||
|
|
||||||
basecount = int(x.groups()[0])
|
|
||||||
series = 0
|
|
||||||
finished = False
|
|
||||||
while not finished:
|
|
||||||
series += 1
|
|
||||||
filename = f'{basecount:06}.{seed}.png'
|
|
||||||
if isbatch or os.path.exists(os.path.join(outdir,filename)):
|
|
||||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
|
||||||
finished = not os.path.exists(os.path.join(outdir,filename))
|
|
||||||
return os.path.join(outdir,filename)
|
|
||||||
|
|
||||||
def _split_weighted_subprompts(text):
|
def _split_weighted_subprompts(text):
|
||||||
"""
|
"""
|
||||||
grabs all text up to the first occurrence of ':'
|
grabs all text up to the first occurrence of ':'
|
||||||
|
251
scripts/dream.py
251
scripts/dream.py
@ -3,18 +3,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shlex
|
import shlex
|
||||||
import atexit
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
from PIL import Image,PngImagePlugin
|
from ldm.dream_util import Completer,PngWriter,PromptFormatter
|
||||||
|
|
||||||
# readline unavailable on windows systems
|
|
||||||
try:
|
|
||||||
import readline
|
|
||||||
readline_available = True
|
|
||||||
except:
|
|
||||||
readline_available = False
|
|
||||||
|
|
||||||
debugging = False
|
debugging = False
|
||||||
|
|
||||||
@ -35,10 +27,6 @@ def main():
|
|||||||
config = "configs/stable-diffusion/v1-inference.yaml"
|
config = "configs/stable-diffusion/v1-inference.yaml"
|
||||||
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||||
|
|
||||||
# command line history will be stored in a file called "~/.dream_history"
|
|
||||||
if readline_available:
|
|
||||||
setup_readline()
|
|
||||||
|
|
||||||
print("* Initializing, be patient...\n")
|
print("* Initializing, be patient...\n")
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
from pytorch_lightning import logging
|
from pytorch_lightning import logging
|
||||||
@ -54,8 +42,6 @@ def main():
|
|||||||
# the user input loop
|
# the user input loop
|
||||||
t2i = T2I(width=width,
|
t2i = T2I(width=width,
|
||||||
height=height,
|
height=height,
|
||||||
batch_size=opt.batch_size,
|
|
||||||
outdir=opt.outdir,
|
|
||||||
sampler_name=opt.sampler_name,
|
sampler_name=opt.sampler_name,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
full_precision=opt.full_precision,
|
full_precision=opt.full_precision,
|
||||||
@ -87,13 +73,13 @@ def main():
|
|||||||
log_path = os.path.join(opt.outdir,'dream_log.txt')
|
log_path = os.path.join(opt.outdir,'dream_log.txt')
|
||||||
with open(log_path,'a') as log:
|
with open(log_path,'a') as log:
|
||||||
cmd_parser = create_cmd_parser()
|
cmd_parser = create_cmd_parser()
|
||||||
main_loop(t2i,cmd_parser,log,infile)
|
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
|
||||||
log.close()
|
log.close()
|
||||||
if infile:
|
if infile:
|
||||||
infile.close()
|
infile.close()
|
||||||
|
|
||||||
|
|
||||||
def main_loop(t2i,parser,log,infile):
|
def main_loop(t2i,outdir,parser,log,infile):
|
||||||
''' prompt/read/execute loop '''
|
''' prompt/read/execute loop '''
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
@ -131,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
|
|||||||
if elements[0]=='cd' and len(elements)>1:
|
if elements[0]=='cd' and len(elements)>1:
|
||||||
if os.path.exists(elements[1]):
|
if os.path.exists(elements[1]):
|
||||||
print(f"setting image output directory to {elements[1]}")
|
print(f"setting image output directory to {elements[1]}")
|
||||||
t2i.outdir=elements[1]
|
outdir=elements[1]
|
||||||
else:
|
else:
|
||||||
print(f"directory {elements[1]} does not exist")
|
print(f"directory {elements[1]} does not exist")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if elements[0]=='pwd':
|
if elements[0]=='pwd':
|
||||||
print(f"current output directory is {t2i.outdir}")
|
print(f"current output directory is {outdir}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
|
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
|
||||||
@ -166,117 +152,77 @@ def main_loop(t2i,parser,log,infile):
|
|||||||
print("Try again with a prompt!")
|
print("Try again with a prompt!")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
|
||||||
|
individual_images = not opt.grid
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if opt.init_img is None:
|
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size)
|
||||||
results = t2i.txt2img(**vars(opt))
|
callback = file_writer.write_image if individual_images else None
|
||||||
else:
|
|
||||||
assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
|
image_list = t2i.prompt2image(image_callback=callback,**vars(opt))
|
||||||
if None not in (opt.width,opt.height):
|
results = file_writer.files_written if individual_images else image_list
|
||||||
print('Warning: width and height options are ignored when modifying an init image')
|
|
||||||
results = t2i.img2img(**vars(opt))
|
if opt.grid and len(results) > 0:
|
||||||
|
grid_img = file_writer.make_grid([r[0] for r in results])
|
||||||
|
filename = file_writer.unique_filename(results[0][1])
|
||||||
|
seeds = [a[1] for a in results]
|
||||||
|
results = [[filename,seeds]]
|
||||||
|
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
|
||||||
|
file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename)
|
||||||
|
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
allVariantResults = []
|
print(e)
|
||||||
if opt.variants is not None:
|
continue
|
||||||
print(f"Generating {opt.variants} variant(s)...")
|
|
||||||
newopt = copy.deepcopy(opt)
|
|
||||||
newopt.iterations = 1
|
|
||||||
newopt.variants = None
|
|
||||||
for r in results:
|
|
||||||
newopt.init_img = r[0]
|
|
||||||
print(f"\t generating variant for {newopt.init_img}")
|
|
||||||
for j in range(0, opt.variants):
|
|
||||||
try:
|
|
||||||
variantResults = t2i.img2img(**vars(newopt))
|
|
||||||
allVariantResults.append([newopt,variantResults])
|
|
||||||
except AssertionError as e:
|
|
||||||
print(e)
|
|
||||||
continue
|
|
||||||
print(f"{opt.variants} Variants generated!")
|
|
||||||
|
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
write_log_message(t2i,opt,results,log)
|
write_log_message(t2i,normalized_prompt,results,log)
|
||||||
|
|
||||||
if allVariantResults:
|
|
||||||
print("Variant outputs:")
|
|
||||||
for vr in allVariantResults:
|
|
||||||
write_log_message(t2i,vr[0],vr[1],log)
|
|
||||||
|
|
||||||
|
|
||||||
print("goodbye!")
|
print("goodbye!")
|
||||||
|
|
||||||
|
# variant generation is going to be superseded by a generalized
|
||||||
|
# "prompt-morph" functionality
|
||||||
|
# def generate_variants(t2i,outdir,opt,previous_gens):
|
||||||
|
# variants = []
|
||||||
|
# print(f"Generating {opt.variants} variant(s)...")
|
||||||
|
# newopt = copy.deepcopy(opt)
|
||||||
|
# newopt.iterations = 1
|
||||||
|
# newopt.variants = None
|
||||||
|
# for r in previous_gens:
|
||||||
|
# newopt.init_img = r[0]
|
||||||
|
# prompt = PromptFormatter(t2i,newopt).normalize_prompt()
|
||||||
|
# print(f"] generating variant for {newopt.init_img}")
|
||||||
|
# for j in range(0,opt.variants):
|
||||||
|
# try:
|
||||||
|
# file_writer = PngWriter(outdir,prompt,newopt.batch_size)
|
||||||
|
# callback = file_writer.write_image
|
||||||
|
# t2i.prompt2image(image_callback=callback,**vars(newopt))
|
||||||
|
# results = file_writer.files_written
|
||||||
|
# variants.append([prompt,results])
|
||||||
|
# except AssertionError as e:
|
||||||
|
# print(e)
|
||||||
|
# continue
|
||||||
|
# print(f'{opt.variants} variants generated')
|
||||||
|
# return variants
|
||||||
|
|
||||||
|
|
||||||
def write_log_message(t2i,opt,results,logfile):
|
def write_log_message(t2i,prompt,results,logfile):
|
||||||
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
|
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
|
||||||
switches = _reconstruct_switches(t2i,opt)
|
|
||||||
prompt_str = ' '.join(switches)
|
|
||||||
|
|
||||||
# when multiple images are produced in batch, then we keep track of where each starts
|
|
||||||
last_seed = None
|
last_seed = None
|
||||||
img_num = 1
|
img_num = 1
|
||||||
batch_size = opt.batch_size or t2i.batch_size
|
|
||||||
seenit = {}
|
seenit = {}
|
||||||
|
|
||||||
seeds = [a[1] for a in results]
|
|
||||||
if batch_size > 1:
|
|
||||||
seeds = f"(seeds for each batch row: {seeds})"
|
|
||||||
else:
|
|
||||||
seeds = f"(seeds for individual images: {seeds})"
|
|
||||||
|
|
||||||
for r in results:
|
for r in results:
|
||||||
seed = r[1]
|
seed = r[1]
|
||||||
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
|
log_message = (f'{r[0]}: {prompt} -S{seed}')
|
||||||
|
|
||||||
if batch_size > 1:
|
|
||||||
if seed != last_seed:
|
|
||||||
img_num = 1
|
|
||||||
log_message += f' # (batch image {img_num} of {batch_size})'
|
|
||||||
else:
|
|
||||||
img_num += 1
|
|
||||||
log_message += f' # (batch image {img_num} of {batch_size})'
|
|
||||||
last_seed = seed
|
|
||||||
print(log_message)
|
print(log_message)
|
||||||
logfile.write(log_message+"\n")
|
logfile.write(log_message+"\n")
|
||||||
logfile.flush()
|
logfile.flush()
|
||||||
if r[0] not in seenit:
|
|
||||||
seenit[r[0]] = True
|
|
||||||
try:
|
|
||||||
if opt.grid:
|
|
||||||
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
|
|
||||||
else:
|
|
||||||
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Could not open file '{r[0]}' for reading")
|
|
||||||
|
|
||||||
def _reconstruct_switches(t2i,opt):
|
|
||||||
'''Normalize the prompt and switches'''
|
|
||||||
switches = list()
|
|
||||||
switches.append(f'"{opt.prompt}"')
|
|
||||||
switches.append(f'-s{opt.steps or t2i.steps}')
|
|
||||||
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
|
|
||||||
switches.append(f'-W{opt.width or t2i.width}')
|
|
||||||
switches.append(f'-H{opt.height or t2i.height}')
|
|
||||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
|
||||||
switches.append(f'-m{t2i.sampler_name}')
|
|
||||||
if opt.variants:
|
|
||||||
switches.append(f'-v{opt.variants}')
|
|
||||||
if opt.init_img:
|
|
||||||
switches.append(f'-I{opt.init_img}')
|
|
||||||
if opt.strength and opt.init_img is not None:
|
|
||||||
switches.append(f'-f{opt.strength or t2i.strength}')
|
|
||||||
if t2i.full_precision:
|
|
||||||
switches.append('-F')
|
|
||||||
return switches
|
|
||||||
|
|
||||||
def _write_prompt_to_png(path,prompt):
|
|
||||||
info = PngImagePlugin.PngInfo()
|
|
||||||
info.add_text("Dream",prompt)
|
|
||||||
im = Image.open(path)
|
|
||||||
im.save(path,"PNG",pnginfo=info)
|
|
||||||
|
|
||||||
def create_argv_parser():
|
def create_argv_parser():
|
||||||
parser = argparse.ArgumentParser(description="Parse script's command line args")
|
parser = argparse.ArgumentParser(description="Parse script's command line args")
|
||||||
parser.add_argument("--laion400m",
|
parser.add_argument("--laion400m",
|
||||||
@ -297,10 +243,6 @@ def create_argv_parser():
|
|||||||
dest='full_precision',
|
dest='full_precision',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="use slower full precision math for calculations")
|
help="use slower full precision math for calculations")
|
||||||
parser.add_argument('-b','--batch_size',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
|
||||||
parser.add_argument('--sampler','-m',
|
parser.add_argument('--sampler','-m',
|
||||||
dest="sampler_name",
|
dest="sampler_name",
|
||||||
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
||||||
@ -336,93 +278,12 @@ def create_cmd_parser():
|
|||||||
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
|
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
|
||||||
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
|
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
|
||||||
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
|
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
|
||||||
parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
|
# variants is going to be superseded by a generalized "prompt-morph" function
|
||||||
|
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
|
||||||
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
|
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
if readline_available:
|
|
||||||
def setup_readline():
|
|
||||||
readline.set_completer(Completer(['cd','pwd',
|
|
||||||
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
|
|
||||||
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
|
|
||||||
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
|
|
||||||
readline.set_completer_delims(" ")
|
|
||||||
readline.parse_and_bind('tab: complete')
|
|
||||||
load_history()
|
|
||||||
|
|
||||||
def load_history():
|
|
||||||
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
|
|
||||||
try:
|
|
||||||
readline.read_history_file(histfile)
|
|
||||||
readline.set_history_length(1000)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
atexit.register(readline.write_history_file,histfile)
|
|
||||||
|
|
||||||
class Completer():
|
|
||||||
def __init__(self,options):
|
|
||||||
self.options = sorted(options)
|
|
||||||
return
|
|
||||||
|
|
||||||
def complete(self,text,state):
|
|
||||||
buffer = readline.get_line_buffer()
|
|
||||||
|
|
||||||
if text.startswith(('-I','--init_img')):
|
|
||||||
return self._path_completions(text,state,('.png'))
|
|
||||||
|
|
||||||
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
|
|
||||||
return self._path_completions(text,state,())
|
|
||||||
|
|
||||||
response = None
|
|
||||||
if state == 0:
|
|
||||||
# This is the first time for this text, so build a match list.
|
|
||||||
if text:
|
|
||||||
self.matches = [s
|
|
||||||
for s in self.options
|
|
||||||
if s and s.startswith(text)]
|
|
||||||
else:
|
|
||||||
self.matches = self.options[:]
|
|
||||||
|
|
||||||
# Return the state'th item from the match list,
|
|
||||||
# if we have that many.
|
|
||||||
try:
|
|
||||||
response = self.matches[state]
|
|
||||||
except IndexError:
|
|
||||||
response = None
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _path_completions(self,text,state,extensions):
|
|
||||||
# get the path so far
|
|
||||||
if text.startswith('-I'):
|
|
||||||
path = text.replace('-I','',1).lstrip()
|
|
||||||
elif text.startswith('--init_img='):
|
|
||||||
path = text.replace('--init_img=','',1).lstrip()
|
|
||||||
else:
|
|
||||||
path = text
|
|
||||||
|
|
||||||
matches = list()
|
|
||||||
|
|
||||||
path = os.path.expanduser(path)
|
|
||||||
if len(path)==0:
|
|
||||||
matches.append(text+'./')
|
|
||||||
else:
|
|
||||||
dir = os.path.dirname(path)
|
|
||||||
dir_list = os.listdir(dir)
|
|
||||||
for n in dir_list:
|
|
||||||
if n.startswith('.') and len(n)>1:
|
|
||||||
continue
|
|
||||||
full_path = os.path.join(dir,n)
|
|
||||||
if full_path.startswith(path):
|
|
||||||
if os.path.isdir(full_path):
|
|
||||||
matches.append(os.path.join(os.path.dirname(text),n)+'/')
|
|
||||||
elif n.endswith(extensions):
|
|
||||||
matches.append(os.path.join(os.path.dirname(text),n))
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = matches[state]
|
|
||||||
except IndexError:
|
|
||||||
response = None
|
|
||||||
return response
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user