# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
import torch
import numpy as np
import random
import os
import traceback
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import transformers
import time
import re
import sys

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_autocast_device, choose_torch_device
"""Simplified text to image API for stable diffusion/latent diffusion
|
|
|
|
|
|
|
|
Example Usage:
|
|
|
|
|
|
|
|
from ldm.simplet2i import T2I
|
2022-08-25 04:42:37 +00:00
|
|
|
|
2022-08-17 01:34:37 +00:00
|
|
|
# Create an object with default values
|
2022-08-25 04:42:37 +00:00
|
|
|
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
|
|
|
config = <path> // configs/stable-diffusion/v1-inference.yaml
|
2022-08-17 16:35:49 +00:00
|
|
|
iterations = <integer> // how many times to run the sampling (1)
|
2022-08-17 01:34:37 +00:00
|
|
|
steps = <integer> // 50
|
|
|
|
seed = <integer> // current system time
|
2022-08-23 21:25:39 +00:00
|
|
|
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
2022-08-17 01:34:37 +00:00
|
|
|
grid = <boolean> // false
|
|
|
|
width = <integer> // image width, multiple of 64 (512)
|
|
|
|
height = <integer> // image height, multiple of 64 (512)
|
|
|
|
cfg_scale = <float> // unconditional guidance scale (7.5)
|
|
|
|
)
|
2022-08-17 16:35:49 +00:00
|
|
|
|
2022-08-17 01:34:37 +00:00
|
|
|
# do the slow model initialization
|
|
|
|
t2i.load_model()
|
|
|
|
|
2022-08-25 22:13:07 +00:00
|
|
|
# Do the fast inference & image generation. Any options passed here
|
2022-08-17 01:34:37 +00:00
|
|
|
# override the default values assigned during class initialization
|
2022-08-25 04:42:37 +00:00
|
|
|
# Will call load_model() if the model was not previously loaded and so
|
|
|
|
# may be slow at first.
|
2022-08-17 16:35:49 +00:00
|
|
|
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
2022-08-25 04:42:37 +00:00
|
|
|
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
|
|
|
outdir = "./outputs/samples",
|
|
|
|
iterations = 3)
|
2022-08-17 16:35:49 +00:00
|
|
|
|
|
|
|
for row in results:
|
|
|
|
print(f'filename={row[0]}')
|
|
|
|
print(f'seed ={row[1]}')
|
2022-08-18 14:47:53 +00:00
|
|
|
|
|
|
|
# Same thing, but using an initial image.
|
2022-08-25 04:42:37 +00:00
|
|
|
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
|
|
|
outdir = "./outputs/,
|
|
|
|
iterations = 3,
|
|
|
|
init_img = "./sketches/horse+rider.png")
|
2022-08-25 22:13:07 +00:00
|
|
|
|
2022-08-18 14:47:53 +00:00
|
|
|
for row in results:
|
|
|
|
print(f'filename={row[0]}')
|
|
|
|
print(f'seed ={row[1]}')
|
2022-08-17 16:35:49 +00:00
|
|
|
|
2022-08-25 04:42:37 +00:00
|
|
|
# Same thing, but we return a series of Image objects, which lets you manipulate them,
|
|
|
|
# combine them, and save them under arbitrary names
|
|
|
|
|
|
|
|
results = t2i.prompt2image(prompt = "an astronaut riding a horse"
|
|
|
|
outdir = "./outputs/")
|
|
|
|
for row in results:
|
|
|
|
im = row[0]
|
|
|
|
seed = row[1]
|
|
|
|
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
|
|
|
|
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
|
|
|
|
|
|
|
|
Note that the old txt2img() and img2img() calls are deprecated but will
|
|
|
|
still work.
|
|
|
|
"""
|
2022-08-17 01:34:37 +00:00
|
|
|
|
|
|
|
|
|
|
|


class T2I:
    """T2I class

    Attributes
    ----------
    model
    config
    iterations
    steps
    seed
    sampler_name
    width
    height
    cfg_scale
    latent_channels
    downsampling_factor
    precision
    strength
    seamless
    embedding_path

    The vast majority of these arguments default to reasonable values.
    """

    def __init__(
        self,
        iterations=1,
        steps=50,
        seed=None,
        cfg_scale=7.5,
        weights='models/ldm/stable-diffusion-v1/model.ckpt',
        config='configs/stable-diffusion/v1-inference.yaml',
        grid=False,
        width=512,
        height=512,
        sampler_name='k_lms',
        latent_channels=4,
        downsampling_factor=8,
        ddim_eta=0.0,  # deterministic
        precision='autocast',
        full_precision=False,
        strength=0.75,  # default in scripts/img2img.py
        seamless=False,
        embedding_path=None,
        device_type='cuda',
        # just to keep track of this parameter when regenerating prompt
        # needs to be replaced when new configuration system implemented.
        latent_diffusion_weights=False,
    ):
        self.iterations = iterations
        self.width = width
        self.height = height
        self.steps = steps
        self.cfg_scale = cfg_scale
        self.weights = weights
        self.config = config
        self.sampler_name = sampler_name
        self.latent_channels = latent_channels
        self.downsampling_factor = downsampling_factor
        self.grid = grid
        self.ddim_eta = ddim_eta
        self.precision = precision
        self.full_precision = True if choose_torch_device() == 'mps' else full_precision
        self.strength = strength
        self.seamless = seamless
        self.embedding_path = embedding_path
        self.device_type = device_type
        self.model = None  # empty for now
        self.sampler = None
        self.device = None
        self.latent_diffusion_weights = latent_diffusion_weights

        if device_type == 'cuda' and not torch.cuda.is_available():
            device_type = choose_torch_device()
            print(">> cuda not available, using device", device_type)
        self.device = torch.device(device_type)

        # for VRAM usage statistics
        device_type = choose_torch_device()
        self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None

        if seed is None:
            self.seed = self._new_seed()
        else:
            self.seed = seed
        transformers.logging.set_verbosity_error()

    def prompt2png(self, prompt, outdir, **kwargs):
        """
        Takes a prompt and an output directory, writes out the requested number
        of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
        Optional named arguments are the same as those passed to T2I and prompt2image()
        """
        results = self.prompt2image(prompt, **kwargs)
        pngwriter = PngWriter(outdir)
        prefix = pngwriter.unique_prefix()
        outputs = []
        for image, seed in results:
            name = f'{prefix}.{seed}.png'
            path = pngwriter.save_image_and_prompt_to_png(
                image, f'{prompt} -S{seed}', name
            )
            outputs.append([path, seed])
        return outputs

    def txt2img(self, prompt, **kwargs):
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        return self.prompt2png(prompt, outdir, **kwargs)

    def img2img(self, prompt, **kwargs):
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        assert (
            'init_img' in kwargs
        ), 'call to img2img() must include the init_img argument'
        return self.prompt2png(prompt, outdir, **kwargs)

    def prompt2image(
        self,
        # these are common
        prompt,
        iterations=None,
        steps=None,
        seed=None,
        cfg_scale=None,
        ddim_eta=None,
        skip_normalize=False,
        image_callback=None,
        step_callback=None,
        width=None,
        height=None,
        seamless=False,
        # these are specific to img2img
        init_img=None,
        fit=False,
        strength=None,
        gfpgan_strength=0,
        save_original=False,
        upscale=None,
        sampler_name=None,
        log_tokenization=False,
        with_variations=None,
        variation_amount=0.0,
        **args,
    ):  # eat up additional cruft
        """
        ldm.prompt2image() is the common entry point for txt2img() and img2img()
        It takes the following arguments:
           prompt           // prompt string (no default)
           iterations       // iterations (1); image count=iterations
           steps            // refinement steps per iteration
           seed             // seed for random number generator
           width            // width of image, in multiples of 64 (512)
           height           // height of image, in multiples of 64 (512)
           cfg_scale        // how strongly the prompt influences the image (7.5) (must be >1)
           seamless         // whether the generated image should tile
           init_img         // path to an initial image - its dimensions override width and height
           strength         // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
           gfpgan_strength  // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
           ddim_eta         // image randomness (eta=0.0 means the same seed always produces the same image)
           step_callback    // a function or method that will be called each step
           image_callback   // a function or method that will be called each time an image is generated
           with_variations  // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
           variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)

        To use the step callback, define a function that receives two arguments:
        - Image GPU data
        - The step number

        To use the image callback, define a function or method that receives two arguments, an Image object
        and the seed. You can then do whatever you like with the image, including converting it to
        different formats and manipulating it. For example:

            def process_image(image, seed):
                image.save(f'images/{seed}.png')

        The callback used by prompt2png() can be found in ldm/dream/pngwriter.py. It contains code
        to create the requested output directory, select a unique informative name for each image, and
        write the prompt into the PNG metadata.
        """
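        # A minimal sketch of how the two callbacks might be wired up; the
        # function names below are hypothetical, not part of this module:
        #
        #     def my_step_callback(gpu_data, step):
        #         print(f'completed step {step}')
        #
        #     def my_image_callback(image, seed):
        #         image.save(f'outputs/{seed}.png')
        #
        #     t2i.prompt2image('an astronaut riding a horse',
        #                      step_callback=my_step_callback,
        #                      image_callback=my_image_callback)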
        # TODO: convert this into a getattr() loop
        steps = steps or self.steps
        width = width or self.width
        height = height or self.height
        seamless = seamless or self.seamless
        cfg_scale = cfg_scale or self.cfg_scale
        ddim_eta = ddim_eta or self.ddim_eta
        iterations = iterations or self.iterations
        strength = strength or self.strength
        self.log_tokenization = log_tokenization
        with_variations = [] if with_variations is None else with_variations

        model = (
            self.load_model()
        )  # will instantiate the model or return it from cache
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                m.padding_mode = 'circular' if seamless else m._orig_padding_mode

        assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
        assert (
            0.0 <= strength <= 1.0
        ), 'can only work with strength in [0.0, 1.0]'
        assert (
            0.0 <= variation_amount <= 1.0
        ), '-v --variation_amount must be in [0.0, 1.0]'

        if len(with_variations) > 0 or variation_amount > 0.0:
            assert seed is not None,\
                'seed must be specified when using with_variations'
            if variation_amount == 0.0:
                assert iterations == 1,\
                    'when using --with_variations, multiple iterations are only possible when using --variation_amount'
            assert all(0 <= weight <= 1 for _, weight in with_variations),\
                f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'

        seed = seed or self.seed
        width, height, _ = self._resolution_check(width, height, log=True)

        # TODO: - Check if this is still necessary to run on M1 devices.
        #       - Move code into ldm.dream.devices to live alongside other
        #         special-hardware casing code.
        if self.precision == 'autocast' and torch.cuda.is_available():
            scope = autocast
        else:
            scope = nullcontext

        if sampler_name and (sampler_name != self.sampler_name):
            self.sampler_name = sampler_name
            self._set_sampler()

        tic = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        results = list()

        try:
            if init_img:
                assert os.path.exists(init_img), f'{init_img}: File not found'
                init_image = self._load_img(init_img, width, height, fit).to(self.device)
                with scope(self.device.type):
                    init_latent = self.model.get_first_stage_encoding(
                        self.model.encode_first_stage(init_image)
                    )  # move to latent space

                make_image = self._img2img(
                    prompt,
                    steps=steps,
                    cfg_scale=cfg_scale,
                    ddim_eta=ddim_eta,
                    skip_normalize=skip_normalize,
                    init_latent=init_latent,
                    strength=strength,
                    callback=step_callback,
                )
            else:
                init_latent = None
                make_image = self._txt2img(
                    prompt,
                    steps=steps,
                    cfg_scale=cfg_scale,
                    ddim_eta=ddim_eta,
                    skip_normalize=skip_normalize,
                    width=width,
                    height=height,
                    callback=step_callback,
                )

            initial_noise = None
            if variation_amount > 0 or len(with_variations) > 0:
                # use fixed initial noise plus random noise per iteration
                seed_everything(seed)
                initial_noise = self._get_noise(init_latent, width, height)
                for v_seed, v_weight in with_variations:
                    seed = v_seed
                    seed_everything(seed)
                    next_noise = self._get_noise(init_latent, width, height)
                    initial_noise = self.slerp(v_weight, initial_noise, next_noise)
                if variation_amount > 0:
                    random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
                    seed = random.randrange(0, np.iinfo(np.uint32).max)

            device_type = choose_autocast_device(self.device)
            with scope(device_type), self.model.ema_scope():
                for n in trange(iterations, desc='Generating'):
                    x_T = None
                    if variation_amount > 0:
                        seed_everything(seed)
                        target_noise = self._get_noise(init_latent, width, height)
                        x_T = self.slerp(variation_amount, initial_noise, target_noise)
                    elif initial_noise is not None:
                        # i.e. we specified particular variations
                        x_T = initial_noise
                    else:
                        seed_everything(seed)
                        if self.device.type == 'mps':
                            x_T = self._get_noise(init_latent, width, height)
                        # make_image will do the equivalent of get_noise itself
                    image = make_image(x_T)
                    results.append([image, seed])
                    if image_callback is not None:
                        image_callback(image, seed)
                    seed = self._new_seed()

            if upscale is not None or gfpgan_strength > 0:
                for result in results:
                    image, seed = result
                    try:
                        if upscale is not None:
                            from ldm.gfpgan.gfpgan_tools import (
                                real_esrgan_upscale,
                            )
                            if len(upscale) < 2:
                                upscale.append(0.75)
                            image = real_esrgan_upscale(
                                image,
                                upscale[1],
                                int(upscale[0]),
                                prompt,
                                seed,
                            )
                        if gfpgan_strength > 0:
                            from ldm.gfpgan.gfpgan_tools import _run_gfpgan

                            image = _run_gfpgan(
                                image, gfpgan_strength, prompt, seed, 1
                            )
                    except Exception as e:
                        print(
                            f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
                        )
                    if image_callback is not None:
                        image_callback(image, seed, upscaled=True)
                    else:  # no callback passed, so we simply replace old image with rescaled one
                        result[0] = image

        except KeyboardInterrupt:
            print('*interrupted*')
            print(
                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
            )
        except RuntimeError as e:
            print(traceback.format_exc(), file=sys.stderr)
            print('>> Are you sure your system has an adequate NVIDIA GPU?')

        toc = time.time()
        print('>> Usage stats:')
        print(
            f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
        )
        print(
            '>> Max VRAM used for this generation:',
            '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
        )

        if self.session_peakmem:
            self.session_peakmem = max(
                self.session_peakmem, torch.cuda.max_memory_allocated()
            )
            print(
                '>> Max VRAM used since script start: ',
                '%4.2fG' % (self.session_peakmem / 1e9),
            )
        return results
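
    # Note on the 'upscale' argument handled above: it is a list of one or two
    # values, e.g. upscale=[2] or upscale=[2, 0.75] -- the integer scale factor
    # first, then an optional ESRGAN strength that defaults to 0.75.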

    @torch.no_grad()
    def _txt2img(
        self,
        prompt,
        steps,
        cfg_scale,
        ddim_eta,
        skip_normalize,
        width,
        height,
        callback,
    ):
        """
        Returns a function returning an image derived from the prompt
        Return value depends on the seed at the time you call it
        """
        sampler = self.sampler

        def make_image(x_T):
            uc, c = self._get_uc_and_c(prompt, skip_normalize)
            shape = [
                self.latent_channels,
                height // self.downsampling_factor,
                width // self.downsampling_factor,
            ]
            samples, _ = sampler.sample(
                batch_size=1,
                S=steps,
                x_T=x_T,
                conditioning=c,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                img_callback=callback,
            )
            return self._sample_to_image(samples)

        return make_image

    @torch.no_grad()
    def _img2img(
        self,
        prompt,
        steps,
        cfg_scale,
        ddim_eta,
        skip_normalize,
        init_latent,
        strength,
        callback,  # Currently not implemented for img2img
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        """
        # PLMS sampler not supported yet, so ignore previous sampler
        if self.sampler_name != 'ddim':
            print(
                f">> sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.device)
        else:
            sampler = self.sampler

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        t_enc = int(strength * steps)

        def make_image(x_T):
            uc, c = self._get_uc_and_c(prompt, skip_normalize)

            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                init_latent,
                torch.tensor([t_enc]).to(self.device),
                noise=x_T,
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback=callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
            )
            return self._sample_to_image(samples)

        return make_image

    # TODO: does this actually need to run every loop? does anything in it vary by random seed?
    def _get_uc_and_c(self, prompt, skip_normalize):

        uc = self.model.get_learned_conditioning([''])

        # get weighted sub-prompts
        weighted_subprompts = T2I._split_weighted_subprompts(
            prompt, skip_normalize
        )

        if len(weighted_subprompts) > 1:
            # i dont know if this is correct.. but it works
            c = torch.zeros_like(uc)
            # normalize each "sub prompt" and add it
            for subprompt, weight in weighted_subprompts:
                self._log_tokenization(subprompt)
                c = torch.add(
                    c,
                    self.model.get_learned_conditioning([subprompt]),
                    alpha=weight,
                )
        else:  # just standard 1 prompt
            self._log_tokenization(prompt)
            c = self.model.get_learned_conditioning([prompt])
        return (uc, c)
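
    # A sketch of the blend performed by _get_uc_and_c() for a multi-part
    # prompt, assuming two sub-prompts whose normalized weights satisfy
    # w1 + w2 == 1, with cond() standing in for model.get_learned_conditioning():
    #
    #     c = w1 * cond(['forest']) + w2 * cond(['fog'])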

    def _sample_to_image(self, samples):
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
        if len(x_samples) != 1:
            raise Exception(
                f'>> expected to get a single image, but got {len(x_samples)}'
            )
        x_sample = 255.0 * rearrange(
            x_samples[0].cpu().numpy(), 'c h w -> h w c'
        )
        return Image.fromarray(x_sample.astype(np.uint8))

    def _new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def load_model(self):
        """Load and initialize the model from configuration variables passed at object creation time"""
        if self.model is None:
            seed_everything(self.seed)
            try:
                config = OmegaConf.load(self.config)
                model = self._load_model_from_config(config, self.weights)
                if self.embedding_path is not None:
                    model.embedding_manager.load(
                        self.embedding_path, self.full_precision
                    )
                self.model = model.to(self.device)
                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                self.model.cond_stage_model.device = self.device
            except AttributeError as e:
                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                raise SystemExit from e

            self._set_sampler()

            # remember each conv layer's original padding mode so that
            # prompt2image() can toggle seamless (circular) padding per generation
            for m in self.model.modules():
                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                    m._orig_padding_mode = m.padding_mode

        return self.model

    # returns a tensor filled with random numbers from a normal distribution
    def _get_noise(self, init_latent, width, height):
        if init_latent is not None:
            if self.device.type == 'mps':
                # on MPS, draw the noise on the CPU and then move it to the device
                return torch.randn_like(init_latent, device='cpu').to(self.device)
            else:
                return torch.randn_like(init_latent, device=self.device)
        else:
            if self.device.type == 'mps':
                return torch.randn(
                    [
                        1,
                        self.latent_channels,
                        height // self.downsampling_factor,
                        width // self.downsampling_factor,
                    ],
                    device='cpu',
                ).to(self.device)
            else:
                return torch.randn(
                    [
                        1,
                        self.latent_channels,
                        height // self.downsampling_factor,
                        width // self.downsampling_factor,
                    ],
                    device=self.device,
                )
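
    # With the default 512x512 configuration, the text-to-image branch of
    # _get_noise() above draws a tensor of shape
    # [1, latent_channels, 512 // 8, 512 // 8] == [1, 4, 64, 64].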

    def _set_sampler(self):
        msg = f'>> Setting Sampler to {self.sampler_name}'
        if self.sampler_name == 'plms':
            self.sampler = PLMSSampler(self.model, device=self.device)
        elif self.sampler_name == 'ddim':
            self.sampler = DDIMSampler(self.model, device=self.device)
        elif self.sampler_name == 'k_dpm_2_a':
            self.sampler = KSampler(
                self.model, 'dpm_2_ancestral', device=self.device
            )
        elif self.sampler_name == 'k_dpm_2':
            self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
        elif self.sampler_name == 'k_euler_a':
            self.sampler = KSampler(
                self.model, 'euler_ancestral', device=self.device
            )
        elif self.sampler_name == 'k_euler':
            self.sampler = KSampler(self.model, 'euler', device=self.device)
        elif self.sampler_name == 'k_heun':
            self.sampler = KSampler(self.model, 'heun', device=self.device)
        elif self.sampler_name == 'k_lms':
            self.sampler = KSampler(self.model, 'lms', device=self.device)
        else:
            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
            self.sampler = PLMSSampler(self.model, device=self.device)

        print(msg)

    def _load_model_from_config(self, config, ckpt):
        print(f'>> Loading model from {ckpt}')
        pl_sd = torch.load(ckpt, map_location='cpu')
        # if "global_step" in pl_sd:
        #     print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd['state_dict']
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        model.to(self.device)
        model.eval()
        if self.full_precision:
            print(
                'Using slower but more accurate full-precision math (--full_precision)'
            )
        else:
            print(
                '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
            )
            model.half()
        return model

    def _load_img(self, path, width, height, fit=False):
        with Image.open(path) as img:
            image = img.convert('RGB')
            print(
                f'>> loaded input image of size {image.width}x{image.height} from {path}'
            )

        # The logic here is:
        # 1. If "fit" is true, then the image will be fit into the bounding box defined
        #    by width and height. It will do this in a way that preserves the init image's
        #    aspect ratio while preventing letterboxing. This means that if there is
        #    leftover horizontal space after rescaling the image to fit in the bounding box,
        #    the generated image's width will be reduced to the rescaled init image's width.
        #    Similarly for the vertical space.
        # 2. Otherwise, if "fit" is false, then the image will be scaled, preserving its
        #    aspect ratio, to the nearest multiple of 64. Large images may generate an
        #    unexpected OOM error.
        if fit:
            image = self._fit_image(image, (width, height))
        else:
            image = self._squeeze_image(image)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return 2.0 * image - 1.0

    def _squeeze_image(self, image):
        x, y, resize_needed = self._resolution_check(image.width, image.height)
        if resize_needed:
            return InitImageResizer(image).resize(x, y)
        return image

    def _fit_image(self, image, max_dimensions):
        w, h = max_dimensions
        print(
            f'>> image will be resized to fit inside a box {w}x{h} in size.'
        )
        if image.width > image.height:
            h = None  # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
        elif image.height > image.width:
            w = None  # ditto for w
        else:
            pass
        image = InitImageResizer(image).resize(w, h)  # note that InitImageResizer does the multiple of 64 truncation internally
        print(
            f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
        )
        return image
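
    # Worked example for _fit_image(): a 1000x500 init image fit into a 512x512
    # box rescales to 512x256 (InitImageResizer also truncates each dimension
    # to a multiple of 64, which 512 and 256 already are).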

    # TO DO: Move this and related weighted subprompt code into its own module.
    def _split_weighted_subprompts(text, skip_normalize=False):
        """
        grabs all text up to the first occurrence of ':'
        uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
        if ':' has no value defined, defaults to 1.0
        repeats until no text remaining
        """
        prompt_parser = re.compile("""
            (?P<prompt>     # capture group for 'prompt'
            (?:\\\:|[^:])+  # match one or more non ':' characters or escaped colons '\:'
            )               # end 'prompt'
            (?:             # non-capture group
            :+              # match one or more ':' characters
            (?P<weight>     # capture group for 'weight'
            -?\d+(?:\.\d+)? # match positive or negative integer or decimal number
            )?              # end weight capture group, make optional
            \s*             # strip spaces after weight
            |               # OR
            $               # else, if no ':' then match end of line
            )               # end non-capture group
            """, re.VERBOSE)
        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
            match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
        if skip_normalize:
            return parsed_prompts
        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
        if weight_sum == 0:
            print(
                "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
            equal_weight = 1 / len(parsed_prompts)
            return [(x[0], equal_weight) for x in parsed_prompts]
        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
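
    # Example of the behavior above (values rounded): 'forest:2 fog:1' parses
    # to [('forest', 2.0), ('fog', 1.0)] with skip_normalize=True, and to
    # [('forest', 0.67), ('fog', 0.33)] after the default normalization by the
    # weight sum.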

    # shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    def _log_tokenization(self, text):
        if not self.log_tokenization:
            return
        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
        tokenized = ""
        discarded = ""
        usedTokens = 0
        totalTokens = len(tokens)
        for i in range(0, totalTokens):
            token = tokens[i].replace('</w>', ' ')
            # alternate color
            s = (usedTokens % 6) + 1
            if i < self.model.cond_stage_model.max_length:
                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                usedTokens += 1
            else:  # over max token length
                discarded = discarded + f"\x1b[0;3{s};40m{token}"
        print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
        if discarded != "":
            print(
                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")

    def _resolution_check(self, width, height, log=False):
        resize_needed = False
        w, h = map(
            lambda x: x - x % 64, (width, height)
        )  # resize to integer multiple of 64
        if h != height or w != width:
            if log:
                print(
                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
                )
            height = h
            width = w
            resize_needed = True

        if (width * height) > (self.width * self.height):
            print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")

        return width, height, resize_needed
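
    # e.g. _resolution_check(513, 768) returns (512, 768, True): 513 is not a
    # multiple of 64, so it is truncated down to 512 and a resize is flagged.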

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        '''
        Spherical linear interpolation
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   colinear. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        '''
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
            v0 = v0.detach().cpu().numpy()
        if not isinstance(v1, np.ndarray):
            inputs_are_torch = True
            v1 = v1.detach().cpu().numpy()

        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            # vectors are nearly colinear; plain linear interpolation is stable here
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(self.device)

        return v2
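
# Usage sketch for T2I.slerp(), interpolating halfway between two latent noise
# tensors (the tensor shapes are illustrative, matching the default 512x512
# latent size; the variable names are hypothetical):
#
#     t2i = T2I()
#     noise_a = torch.randn(1, 4, 64, 64)
#     noise_b = torch.randn(1, 4, 64, 64)
#     halfway = t2i.slerp(0.5, noise_a, noise_b)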