mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactoring simplet2i (#387)
* start refactoring -not yet functional * first phase of refactor done - not sure weighted prompts working * Second phase of refactoring. Everything mostly working. * The refactoring has moved all the hard-core inference work into ldm.dream.generator.*, where there are submodules for txt2img and img2img. inpaint will go in there as well. * Some additional refactoring will be done soon, but relatively minor work. * fix -save_orig flag to actually work * add @neonsecret attention.py memory optimization * remove unneeded imports * move token logging into conditioning.py * add placeholder version of inpaint; porting in progress * fix crash in img2img * inpainting working; not tested on variations * fix crashes in img2img * ported attention.py memory optimization #117 from basujindal branch * added @torch_no_grad() decorators to img2img, txt2img, inpaint closures * Final commit prior to PR against development * fixup crash when generating intermediate images in web UI * rename ldm.simplet2i to ldm.generate * add backward-compatibility simplet2i shell with deprecation warning * add back in mps exception, addresses @vargol comment in #354 * replaced Conditioning class with exported functions * fix wrong type of with_variations attribute during intialization * changed "image_iterator()" to "get_make_image()" * raise NotImplementedError for calling get_make_image() in parent class * Update ldm/generate.py better error message Co-authored-by: Kevin Gibbons <bakkot@gmail.com> * minor stylistic fixes and assertion checks from code review * moved get_noise() method into img2img class * break get_noise() into two methods, one for txt2img and the other for img2img * inpainting works on non-square images now * make get_noise() an abstract method in base class * much improved inpainting Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
This commit is contained in:
parent
1ad2a8e567
commit
720e5cd651
96
ldm/dream/conditioning.py
Normal file
96
ldm/dream/conditioning.py
Normal file
@ -0,0 +1,96 @@
|
||||
'''
|
||||
This module handles the generation of the conditioning tensors, including management of
|
||||
weighted subprompts.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c() get the conditioned and unconditioned latent
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
|
||||
'''
|
||||
import re
|
||||
import torch
|
||||
|
||||
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
|
||||
uc = model.get_learned_conditioning([''])
|
||||
|
||||
# get weighted sub-prompts
|
||||
weighted_subprompts = split_weighted_subprompts(
|
||||
prompt, skip_normalize
|
||||
)
|
||||
|
||||
if len(weighted_subprompts) > 1:
|
||||
# i dont know if this is correct.. but it works
|
||||
c = torch.zeros_like(uc)
|
||||
# normalize each "sub prompt" and add it
|
||||
for subprompt, weight in weighted_subprompts:
|
||||
log_tokenization(subprompt, model, log_tokens)
|
||||
c = torch.add(
|
||||
c,
|
||||
model.get_learned_conditioning([subprompt]),
|
||||
alpha=weight,
|
||||
)
|
||||
else: # just standard 1 prompt
|
||||
log_tokenization(prompt, model, log_tokens)
|
||||
c = model.get_learned_conditioning([prompt])
|
||||
return (uc, c)
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / len(parsed_prompts)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
@ -1,4 +1,6 @@
|
||||
import torch
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
def choose_torch_device() -> str:
|
||||
'''Convenience routine for guessing which GPU device to run model on'''
|
||||
@ -8,10 +10,11 @@ def choose_torch_device() -> str:
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
def choose_autocast_device(device) -> str:
|
||||
def choose_autocast_device(device):
|
||||
'''Returns an autocast compatible device from a torch device'''
|
||||
device_type = device.type # this returns 'mps' on M1
|
||||
# autocast only supports cuda or cpu
|
||||
if device_type not in ('cuda','cpu'):
|
||||
return 'cpu'
|
||||
return device_type
|
||||
if device_type in ('cuda','cpu'):
|
||||
return device_type,autocast
|
||||
else:
|
||||
return 'cpu',nullcontext
|
||||
|
4
ldm/dream/generator/__init__.py
Normal file
4
ldm/dream/generator/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the ldm.dream.generator package
|
||||
'''
|
||||
from .base import Generator
|
158
ldm/dream/generator/base.py
Normal file
158
ldm/dream/generator/base.py
Normal file
@ -0,0 +1,158 @@
|
||||
'''
|
||||
Base class for ldm.dream.generator.*
|
||||
including img2img, txt2img, and inpaint
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.dream.devices import choose_autocast_device
|
||||
|
||||
downsampling = 8
|
||||
|
||||
class Generator():
|
||||
def __init__(self,model):
|
||||
self.model = model
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
device_type,scope = choose_autocast_device(self.model.device)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
init_image = init_image,
|
||||
width = width,
|
||||
height = height,
|
||||
step_callback = step_callback,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
with scope(device_type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
seed_everything(seed)
|
||||
target_noise = self.get_noise(width,height)
|
||||
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
seed_everything(seed)
|
||||
if self.model.device.type == 'mps':
|
||||
x_T = self.get_noise(width,height)
|
||||
|
||||
# make_image will do the equivalent of get_noise itself
|
||||
image = make_image(x_T)
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
if len(x_samples) != 1:
|
||||
raise Exception(
|
||||
f'>> expected to get a single image, but got {len(x_samples)}')
|
||||
x_sample = 255.0 * rearrange(
|
||||
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||
)
|
||||
return Image.fromarray(x_sample.astype(np.uint8))
|
||||
|
||||
def generate_initial_noise(self, seed, width, height):
|
||||
initial_noise = None
|
||||
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
||||
# use fixed initial noise plus random noise per iteration
|
||||
seed_everything(seed)
|
||||
initial_noise = self.get_noise(width,height)
|
||||
for v_seed, v_weight in self.with_variations:
|
||||
seed = v_seed
|
||||
seed_everything(seed)
|
||||
next_noise = self.get_noise(width,height)
|
||||
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||
if self.variation_amount > 0:
|
||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||
return (seed, initial_noise)
|
||||
else:
|
||||
return (seed, None)
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
"""
|
||||
Returns a tensor filled with random numbers, either form a normal distribution
|
||||
(txt2img) or from the latent image (img2img, inpaint)
|
||||
"""
|
||||
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
'''
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
'''
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(self.model.device)
|
||||
|
||||
return v2
|
||||
|
72
ldm/dream/generator/img2img.py
Normal file
72
ldm/dream/generator/img2img.py
Normal file
@ -0,0 +1,72 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img descends from ldm.dream.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.dream.devices import choose_autocast_device
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self,model):
|
||||
super().__init__(model)
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
|
||||
# PLMS sampler not supported yet, so ignore previous sampler
|
||||
if not isinstance(sampler,DDIMSampler):
|
||||
print(
|
||||
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
device_type,scope = choose_autocast_device(self.model.device)
|
||||
with scope(device_type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
)
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
init_latent = self.init_latent
|
||||
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
||||
if device.type == 'mps':
|
||||
return torch.randn_like(init_latent, device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn_like(init_latent, device=device)
|
76
ldm/dream/generator/inpaint.py
Normal file
76
ldm/dream/generator/inpaint.py
Normal file
@ -0,0 +1,76 @@
|
||||
'''
|
||||
ldm.dream.generator.inpaint descends from ldm.dream.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from ldm.dream.devices import choose_autocast_device
|
||||
from ldm.dream.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self,model):
|
||||
super().__init__(model)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,init_mask,strength,
|
||||
step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
init_mask = init_mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
||||
init_mask = repeat(init_mask, '1 ... -> b ...', b=1)
|
||||
|
||||
# PLMS sampler not supported yet, so ignore previous sampler
|
||||
if not isinstance(sampler,DDIMSampler):
|
||||
print(
|
||||
f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
device_type,scope = choose_autocast_device(self.model.device)
|
||||
with scope(device_type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
mask = init_mask,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
|
61
ldm/dream/generator/txt2img.py
Normal file
61
ldm/dream/generator/txt2img.py
Normal file
@ -0,0 +1,61 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img inherits from ldm.dream.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.dream.generator.base import Generator
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self,model):
|
||||
super().__init__(model)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
]
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
)
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
@ -59,10 +59,16 @@ class PromptFormatter:
|
||||
switches.append(f'-H{opt.height or t2i.height}')
|
||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
|
||||
# to do: put model name into the t2i object
|
||||
# switches.append(f'--model{t2i.model_name}')
|
||||
if opt.invert_mask:
|
||||
switches.append(f'--invert_mask')
|
||||
if opt.seamless or t2i.seamless:
|
||||
switches.append(f'--seamless')
|
||||
if opt.init_img:
|
||||
switches.append(f'-I{opt.init_img}')
|
||||
if opt.mask:
|
||||
switches.append(f'-M{opt.mask}')
|
||||
if opt.fit:
|
||||
switches.append(f'--fit')
|
||||
if opt.strength and opt.init_img is not None:
|
||||
|
@ -22,7 +22,7 @@ class Completer:
|
||||
def complete(self, text, state):
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if text.startswith(('-I', '--init_img')):
|
||||
if text.startswith(('-I', '--init_img','-M','--init_mask')):
|
||||
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
|
||||
|
||||
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
|
||||
@ -48,10 +48,15 @@ class Completer:
|
||||
|
||||
def _path_completions(self, text, state, extensions):
|
||||
# get the path so far
|
||||
# TODO: replace this mess with a regular expression match
|
||||
if text.startswith('-I'):
|
||||
path = text.replace('-I', '', 1).lstrip()
|
||||
elif text.startswith('--init_img='):
|
||||
path = text.replace('--init_img=', '', 1).lstrip()
|
||||
elif text.startswith('--init_mask='):
|
||||
path = text.replace('--init_mask=', '', 1).lstrip()
|
||||
elif text.startswith('-M'):
|
||||
path = text.replace('-M', '', 1).lstrip()
|
||||
else:
|
||||
path = text
|
||||
|
||||
@ -94,6 +99,7 @@ if readline_available:
|
||||
'--grid','-g',
|
||||
'--individual','-i',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--strength','-f',
|
||||
'--variants','-v',
|
||||
'--outdir','-o',
|
||||
|
@ -144,7 +144,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
# and don't bother with the last one, since it'll render anyway
|
||||
nonlocal step_index
|
||||
if progress_images and step % 5 == 0 and step < steps - 1:
|
||||
image = self.model._sample_to_image(sample)
|
||||
image = self.model.sample_to_image(sample)
|
||||
name = f'{prefix}.{seed}.{step_index}.png'
|
||||
metadata = f'{prompt} -S{seed} [intermediate]'
|
||||
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
|
||||
|
642
ldm/generate.py
Normal file
642
ldm/generate.py
Normal file
@ -0,0 +1,642 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import transformers
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.image_util import InitImageResizer
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.dream.conditioning import get_uc_and_c
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
Example Usage:
|
||||
|
||||
from ldm.generate import Generate
|
||||
|
||||
# Create an object with default values
|
||||
gr = Generate(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||
config = <path> // configs/stable-diffusion/v1-inference.yaml
|
||||
iterations = <integer> // how many times to run the sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
cfg_scale = <float> // condition-free guidance scale (7.5)
|
||||
)
|
||||
|
||||
# do the slow model initialization
|
||||
gr.load_model()
|
||||
|
||||
# Do the fast inference & image generation. Any options passed here
|
||||
# override the default values assigned during class initialization
|
||||
# Will call load_model() if the model was not previously loaded and so
|
||||
# may be slow at first.
|
||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||
results = gr.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/samples",
|
||||
iterations = 3)
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but using an initial image.
|
||||
results = gr.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/,
|
||||
iterations = 3,
|
||||
init_img = "./sketches/horse+rider.png")
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but we return a series of Image objects, which lets you manipulate them,
|
||||
# combine them, and save them under arbitrary names
|
||||
|
||||
results = gr.prompt2image(prompt = "an astronaut riding a horse"
|
||||
outdir = "./outputs/")
|
||||
for row in results:
|
||||
im = row[0]
|
||||
seed = row[1]
|
||||
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
|
||||
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
|
||||
|
||||
Note that the old txt2img() and img2img() calls are deprecated but will
|
||||
still work.
|
||||
"""
|
||||
|
||||
|
||||
class Generate:
|
||||
"""Generate class
|
||||
Stores default values for multiple configuration items
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterations = 1,
|
||||
steps = 50,
|
||||
cfg_scale = 7.5,
|
||||
weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
|
||||
config = 'configs/stable-diffusion/v1-inference.yaml',
|
||||
grid = False,
|
||||
width = 512,
|
||||
height = 512,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
precision = 'autocast',
|
||||
full_precision = False,
|
||||
strength = 0.75, # default in scripts/img2img.py
|
||||
seamless = False,
|
||||
embedding_path = None,
|
||||
device_type = 'cuda',
|
||||
):
|
||||
self.iterations = iterations
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.steps = steps
|
||||
self.cfg_scale = cfg_scale
|
||||
self.weights = weights
|
||||
self.config = config
|
||||
self.sampler_name = sampler_name
|
||||
self.grid = grid
|
||||
self.ddim_eta = ddim_eta
|
||||
self.precision = precision
|
||||
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
||||
self.strength = strength
|
||||
self.seamless = seamless
|
||||
self.embedding_path = embedding_path
|
||||
self.device_type = device_type
|
||||
self.model = None # empty for now
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.generators = {}
|
||||
self.base_generator = None
|
||||
self.seed = None
|
||||
|
||||
if device_type == 'cuda' and not torch.cuda.is_available():
|
||||
device_type = choose_torch_device()
|
||||
print(">> cuda not available, using device", device_type)
|
||||
self.device = torch.device(device_type)
|
||||
|
||||
# for VRAM usage statistics
|
||||
device_type = choose_torch_device()
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
|
||||
Optional named arguments are the same as those passed to Generate and prompt2image()
|
||||
"""
|
||||
results = self.prompt2image(prompt, **kwargs)
|
||||
pngwriter = PngWriter(outdir)
|
||||
prefix = pngwriter.unique_prefix()
|
||||
outputs = []
|
||||
for image, seed in results:
|
||||
name = f'{prefix}.{seed}.png'
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image, f'{prompt} -S{seed}', name)
|
||||
outputs.append([path, seed])
|
||||
return outputs
|
||||
|
||||
def txt2img(self, prompt, **kwargs):
|
||||
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
||||
return self.prompt2png(prompt, outdir, **kwargs)
|
||||
|
||||
def img2img(self, prompt, **kwargs):
|
||||
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
||||
assert (
|
||||
'init_img' in kwargs
|
||||
), 'call to img2img() must include the init_img argument'
|
||||
return self.prompt2png(prompt, outdir, **kwargs)
|
||||
|
||||
def prompt2image(
|
||||
self,
|
||||
# these are common
|
||||
prompt,
|
||||
iterations = None,
|
||||
steps = None,
|
||||
seed = None,
|
||||
cfg_scale = None,
|
||||
ddim_eta = None,
|
||||
skip_normalize = False,
|
||||
image_callback = None,
|
||||
step_callback = None,
|
||||
width = None,
|
||||
height = None,
|
||||
sampler_name = None,
|
||||
seamless = False,
|
||||
log_tokenization= False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
# these are specific to img2img
|
||||
init_img = None,
|
||||
mask = None,
|
||||
invert_mask = False,
|
||||
fit = False,
|
||||
strength = None,
|
||||
# these are specific to GFPGAN/ESRGAN
|
||||
gfpgan_strength= 0,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
ldm.prompt2image() is the common entry point for txt2img() and img2img()
|
||||
It takes the following arguments:
|
||||
prompt // prompt string (no default)
|
||||
iterations // iterations (1); image count=iterations
|
||||
steps // refinement steps per iteration
|
||||
seed // seed for random number generator
|
||||
width // width of image, in multiples of 64 (512)
|
||||
height // height of image, in multiples of 64 (512)
|
||||
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
|
||||
seamless // whether the generated image should tile
|
||||
init_img // path to an initial image
|
||||
mask // path to an initial image mask for inpainting
|
||||
invert_mask // paint over opaque areas, retain transparent areas
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
|
||||
variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
|
||||
|
||||
To use the step callback, define a function that receives two arguments:
|
||||
- Image GPU data
|
||||
- The step number
|
||||
|
||||
To use the image callback, define a function of method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
def process_image(image,seed):
|
||||
image.save(f{'images/seed.png'})
|
||||
|
||||
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
|
||||
to create the requested output directory, select a unique informative name for each image, and
|
||||
write the prompt into the PNG metadata.
|
||||
"""
|
||||
# TODO: convert this into a getattr() loop
|
||||
steps = steps or self.steps
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
seamless = seamless or self.seamless
|
||||
cfg_scale = cfg_scale or self.cfg_scale
|
||||
ddim_eta = ddim_eta or self.ddim_eta
|
||||
iterations = iterations or self.iterations
|
||||
strength = strength or self.strength
|
||||
self.seed = seed
|
||||
self.log_tokenization = log_tokenization
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
model = (
|
||||
self.load_model()
|
||||
) # will instantiate the model or return it from cache
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert (
|
||||
0.0 < strength < 1.0
|
||||
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
|
||||
assert (
|
||||
0.0 <= variation_amount <= 1.0
|
||||
), '-v --variation_amount must be in [0.0, 1.0]'
|
||||
|
||||
# check this logic - doesn't look right
|
||||
if len(with_variations) > 0 or variation_amount > 1.0:
|
||||
assert seed is not None,\
|
||||
'seed must be specified when using with_variations'
|
||||
if variation_amount == 0.0:
|
||||
assert iterations == 1,\
|
||||
'when using --with_variations, multiple iterations are only possible when using --variation_amount'
|
||||
assert all(0 <= weight <= 1 for _, weight in with_variations),\
|
||||
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
|
||||
|
||||
width, height, _ = self._resolution_check(width, height, log=True)
|
||||
|
||||
if sampler_name and (sampler_name != self.sampler_name):
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
tic = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
results = list()
|
||||
init_image = None
|
||||
init_mask_image = None
|
||||
|
||||
try:
|
||||
uc, c = get_uc_and_c(
|
||||
prompt, model=self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
log_tokens=self.log_tokenization
|
||||
)
|
||||
|
||||
if mask and not init_img:
|
||||
raise AssertionError('If mask path is provided, initial image path should be provided as well')
|
||||
|
||||
if mask and init_img:
|
||||
init_image,size1 = self._load_img(init_img, width, height,fit=fit)
|
||||
init_image.to(self.device)
|
||||
init_mask_image,size2 = self._load_img_mask(mask, width, height,fit=fit, invert=invert_mask)
|
||||
init_mask_image.to(self.device)
|
||||
assert size1==size2,f"for inpainting, the initial image and its mask must be identical sizes, instead got {size1} vs {size2}"
|
||||
generator = self._make_inpaint()
|
||||
elif init_img: # little bit of repeated code here, but makes logic clearer
|
||||
init_image,_ = self._load_img(init_img, width, height, fit=fit)
|
||||
init_image.to(self.device)
|
||||
generator = self._make_img2img()
|
||||
else:
|
||||
generator = self._make_txt2img()
|
||||
|
||||
generator.set_variation(self.seed, variation_amount, with_variations)
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations = iterations,
|
||||
seed = self.seed,
|
||||
sampler = self.sampler,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
conditioning = (uc,c),
|
||||
ddim_eta = ddim_eta,
|
||||
image_callback = image_callback, # called after the final image is generated
|
||||
step_callback = step_callback, # called after each intermediate image is generated
|
||||
width = width,
|
||||
height = height,
|
||||
init_image = init_image, # notice that init_image is different from init_img
|
||||
init_mask = init_mask_image,
|
||||
strength = strength
|
||||
)
|
||||
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
self.upscale_and_reconstruct(results,
|
||||
upscale = upscale,
|
||||
strength = gfpgan_strength,
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('*interrupted*')
|
||||
print(
|
||||
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Are you sure your system has an adequate GPU?')
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
print(
|
||||
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
print(
|
||||
f'>> Max VRAM used for this generation:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
if self.session_peakmem:
|
||||
self.session_peakmem = max(
|
||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||
)
|
||||
print(
|
||||
f'>> Max VRAM used since script start: ',
|
||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||
)
|
||||
return results
|
||||
|
||||
def _make_img2img(self):
|
||||
if not self.generators.get('img2img'):
|
||||
from ldm.dream.generator.img2img import Img2Img
|
||||
self.generators['img2img'] = Img2Img(self.model)
|
||||
return self.generators['img2img']
|
||||
|
||||
def _make_txt2img(self):
|
||||
if not self.generators.get('txt2img'):
|
||||
from ldm.dream.generator.txt2img import Txt2Img
|
||||
self.generators['txt2img'] = Txt2Img(self.model)
|
||||
return self.generators['txt2img']
|
||||
|
||||
def _make_inpaint(self):
|
||||
if not self.generators.get('inpaint'):
|
||||
from ldm.dream.generator.inpaint import Inpaint
|
||||
self.generators['inpaint'] = Inpaint(self.model)
|
||||
return self.generators['inpaint']
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
try:
|
||||
config = OmegaConf.load(self.config)
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.full_precision
|
||||
)
|
||||
self.model = model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
except AttributeError as e:
|
||||
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
raise SystemExit from e
|
||||
|
||||
self._set_sampler()
|
||||
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
|
||||
return self.model
|
||||
|
||||
def upscale_and_reconstruct(self,
|
||||
image_list,
|
||||
upscale = None,
|
||||
strength = 0.0,
|
||||
save_original = False,
|
||||
image_callback = None):
|
||||
try:
|
||||
if upscale is not None:
|
||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||
if strength > 0:
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
return
|
||||
|
||||
for r in image_list:
|
||||
image, seed = r
|
||||
try:
|
||||
if upscale is not None:
|
||||
if len(upscale) < 2:
|
||||
upscale.append(0.75)
|
||||
image = real_esrgan_upscale(
|
||||
image,
|
||||
upscale[1],
|
||||
int(upscale[0]),
|
||||
seed,
|
||||
)
|
||||
if strength > 0:
|
||||
image = run_gfpgan(
|
||||
image, strength, seed, 1
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, upscaled=True)
|
||||
else:
|
||||
r[0] = image
|
||||
|
||||
# to help WebGUI - front end to generator util function
|
||||
def sample_to_image(self,samples):
|
||||
return self._sample_to_image(samples)
|
||||
|
||||
def _sample_to_image(self,samples):
|
||||
if not self.base_generator:
|
||||
from ldm.dream.generator import Generator
|
||||
self.base_generator = Generator(self.model)
|
||||
return self.base_generator.sample_to_image(samples)
|
||||
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||
else:
|
||||
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
|
||||
print(msg)
|
||||
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
print(f'>> Loading model from {ckpt}')
|
||||
pl_sd = torch.load(ckpt, map_location='cpu')
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
if self.full_precision:
|
||||
print(
|
||||
'>> Using slower but more accurate full-precision math (--full_precision)'
|
||||
)
|
||||
else:
|
||||
print(
|
||||
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
|
||||
)
|
||||
model.half()
|
||||
return model
|
||||
|
||||
def _load_img(self, path, width, height, fit=False):
|
||||
assert os.path.exists(path), f'>> {path}: File not found'
|
||||
|
||||
with Image.open(path) as img:
|
||||
image = img.convert('RGB')
|
||||
print(
|
||||
f'>> loaded input image of size {image.width}x{image.height} from {path}'
|
||||
)
|
||||
if fit:
|
||||
image = self._fit_image(image,(width,height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
|
||||
size = image.size
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.device),size
|
||||
|
||||
def _load_img_mask(self, path, width, height, fit=False, invert=False):
|
||||
assert os.path.exists(path), f'>> {path}: File not found'
|
||||
|
||||
image = Image.open(path)
|
||||
print(
|
||||
f'>> loaded input mask of size {image.width}x{image.height} from {path}'
|
||||
)
|
||||
|
||||
if fit:
|
||||
image = self._fit_image(image,(width,height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
|
||||
# convert into a black/white mask
|
||||
image = self._mask_to_image(image,invert)
|
||||
image = image.convert('RGB')
|
||||
size = image.size
|
||||
|
||||
# not quite sure what's going on here. It is copied from basunjindal's implementation
|
||||
# image = image.resize((64, 64), resample=Image.Resampling.LANCZOS)
|
||||
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
|
||||
from ldm.dream.generator.base import downsampling
|
||||
image = image.resize((size[0]//downsampling, size[1]//downsampling), resample=Image.Resampling.LANCZOS)
|
||||
image = np.array(image)
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return image.to(self.device),size
|
||||
|
||||
# The mask is expected to have the region to be inpainted
|
||||
# with alpha transparency. It converts it into a black/white
|
||||
# image with the transparent part black.
|
||||
def _mask_to_image(self, init_mask, invert=False) -> Image:
|
||||
if self._has_transparency(init_mask):
|
||||
# Obtain the mask from the transparency channel
|
||||
mask = Image.new(mode="L", size=init_mask.size, color=255)
|
||||
mask.putdata(init_mask.getdata(band=3))
|
||||
if invert:
|
||||
mask = ImageOps.invert(mask)
|
||||
return mask
|
||||
else:
|
||||
print(f'>> No transparent pixels in this image. Will paint across entire image.')
|
||||
return Image.new(mode="L", size=mask.size, color=0)
|
||||
|
||||
def _has_transparency(self,image):
|
||||
if image.info.get("transparency", None) is not None:
|
||||
return True
|
||||
if image.mode == "P":
|
||||
transparent = image.info.get("transparency", -1)
|
||||
for _, index in image.getcolors():
|
||||
if index == transparent:
|
||||
return True
|
||||
elif image.mode == "RGBA":
|
||||
extrema = image.getextrema()
|
||||
if extrema[3][0] < 255:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _squeeze_image(self,image):
|
||||
x,y,resize_needed = self._resolution_check(image.width,image.height)
|
||||
if resize_needed:
|
||||
return InitImageResizer(image).resize(x,y)
|
||||
return image
|
||||
|
||||
|
||||
def _fit_image(self,image,max_dimensions):
|
||||
w,h = max_dimensions
|
||||
print(
|
||||
f'>> image will be resized to fit inside a box {w}x{h} in size.'
|
||||
)
|
||||
if image.width > image.height:
|
||||
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
|
||||
elif image.height > image.width:
|
||||
w = None # ditto for w
|
||||
else:
|
||||
pass
|
||||
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
|
||||
print(
|
||||
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
|
||||
)
|
||||
return image
|
||||
|
||||
def _resolution_check(self, width, height, log=False):
|
||||
resize_needed = False
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (width, height)
|
||||
) # resize to integer multiple of 64
|
||||
if h != height or w != width:
|
||||
if log:
|
||||
print(
|
||||
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
||||
)
|
||||
height = h
|
||||
width = w
|
||||
resize_needed = True
|
||||
|
||||
if (width * height) > (self.width * self.height):
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
|
||||
return width, height, resize_needed
|
||||
|
||||
|
@ -13,8 +13,8 @@ opt = arg_parser.parse_args()
|
||||
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
|
||||
gfpgan_model_exists = os.path.isfile(model_path)
|
||||
|
||||
def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4):
|
||||
print(f'>> GFPGAN - Restoring Faces: {prompt} : seed:{seed}')
|
||||
def run_gfpgan(image, strength, seed, upsampler_scale=4):
|
||||
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
||||
gfpgan = None
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
@ -127,9 +127,9 @@ def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
|
||||
return bg_upsampler
|
||||
|
||||
|
||||
def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed):
|
||||
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
|
||||
print(
|
||||
f'>> Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x'
|
||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
|
@ -171,6 +171,7 @@ class DDIMSampler(object):
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
# This routine gets called from img2img
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
@ -270,6 +271,7 @@ class DDIMSampler(object):
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# This routine gets called from ddim_sampling() and decode()
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
@ -372,14 +374,16 @@ class DDIMSampler(object):
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
):
|
||||
|
||||
timesteps = (
|
||||
@ -395,6 +399,8 @@ class DDIMSampler(object):
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
@ -403,6 +409,14 @@ class DDIMSampler(object):
|
||||
device=x_latent.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
xdec_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
@ -412,6 +426,7 @@ class DDIMSampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
|
||||
if img_callback:
|
||||
img_callback(x_dec, i)
|
||||
|
||||
|
@ -13,7 +13,7 @@ def exists(val):
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
@ -45,18 +45,19 @@ class GEGLU(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -73,9 +74,7 @@ def zero_module(module):
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
@ -83,28 +82,17 @@ class LinearAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(
|
||||
qkv,
|
||||
'b (qkv heads c) h w -> qkv b heads c (h w)',
|
||||
heads=self.heads,
|
||||
qkv=3,
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(
|
||||
out,
|
||||
'b heads c (h w) -> b (heads c) h w',
|
||||
heads=self.heads,
|
||||
h=h,
|
||||
w=w,
|
||||
)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@ -114,18 +102,26 @@ class SpatialSelfAttention(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -135,12 +131,12 @@ class SpatialSelfAttention(nn.Module):
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
@ -150,18 +146,16 @@ class SpatialSelfAttention(nn.Module):
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
|
||||
):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
@ -169,7 +163,8 @@ class CrossAttention(nn.Module):
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
@ -179,59 +174,43 @@ class CrossAttention(nn.Module):
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
|
||||
)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
del mask
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
# attention, what we cannot get enough of, by halves
|
||||
sim[4:] = sim[4:].softmax(dim=-1)
|
||||
sim[:4] = sim[:4].softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
sim = einsum('b i j, b j d -> b i d', sim, v)
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = x.contiguous() if x.device.type == 'mps' else x
|
||||
@ -249,43 +228,29 @@ class SpatialTransformer(nn.Module):
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
):
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
for d in range(depth)
|
||||
]
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(
|
||||
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
)
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
|
863
ldm/simplet2i.py
863
ldm/simplet2i.py
@ -1,856 +1,13 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
'''
|
||||
This module is provided for backward compatibility with the
|
||||
original (hasty) API.
|
||||
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
Please use ldm.generate instead.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import traceback
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import transformers
|
||||
import time
|
||||
import re
|
||||
import sys
|
||||
from ldm.generate import Generate
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.image_util import InitImageResizer
|
||||
from ldm.dream.devices import choose_autocast_device, choose_torch_device
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
Example Usage:
|
||||
|
||||
from ldm.simplet2i import T2I
|
||||
|
||||
# Create an object with default values
|
||||
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||
config = <path> // configs/stable-diffusion/v1-inference.yaml
|
||||
iterations = <integer> // how many times to run the sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
cfg_scale = <float> // unconditional guidance scale (7.5)
|
||||
)
|
||||
|
||||
# do the slow model initialization
|
||||
t2i.load_model()
|
||||
|
||||
# Do the fast inference & image generation. Any options passed here
|
||||
# override the default values assigned during class initialization
|
||||
# Will call load_model() if the model was not previously loaded and so
|
||||
# may be slow at first.
|
||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/samples",
|
||||
iterations = 3)
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but using an initial image.
|
||||
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
|
||||
outdir = "./outputs/,
|
||||
iterations = 3,
|
||||
init_img = "./sketches/horse+rider.png")
|
||||
|
||||
for row in results:
|
||||
print(f'filename={row[0]}')
|
||||
print(f'seed ={row[1]}')
|
||||
|
||||
# Same thing, but we return a series of Image objects, which lets you manipulate them,
|
||||
# combine them, and save them under arbitrary names
|
||||
|
||||
results = t2i.prompt2image(prompt = "an astronaut riding a horse"
|
||||
outdir = "./outputs/")
|
||||
for row in results:
|
||||
im = row[0]
|
||||
seed = row[1]
|
||||
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
|
||||
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
|
||||
|
||||
Note that the old txt2img() and img2img() calls are deprecated but will
|
||||
still work.
|
||||
"""
|
||||
|
||||
|
||||
class T2I:
|
||||
"""T2I class
|
||||
Attributes
|
||||
----------
|
||||
model
|
||||
config
|
||||
iterations
|
||||
steps
|
||||
seed
|
||||
sampler_name
|
||||
width
|
||||
height
|
||||
cfg_scale
|
||||
latent_channels
|
||||
downsampling_factor
|
||||
precision
|
||||
strength
|
||||
seamless
|
||||
embedding_path
|
||||
|
||||
The vast majority of these arguments default to reasonable values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterations=1,
|
||||
steps=50,
|
||||
seed=None,
|
||||
cfg_scale=7.5,
|
||||
weights='models/ldm/stable-diffusion-v1/model.ckpt',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
grid=False,
|
||||
width=512,
|
||||
height=512,
|
||||
sampler_name='k_lms',
|
||||
latent_channels=4,
|
||||
downsampling_factor=8,
|
||||
ddim_eta=0.0, # deterministic
|
||||
precision='autocast',
|
||||
full_precision=False,
|
||||
strength=0.75, # default in scripts/img2img.py
|
||||
seamless=False,
|
||||
embedding_path=None,
|
||||
device_type = 'cuda',
|
||||
# just to keep track of this parameter when regenerating prompt
|
||||
# needs to be replaced when new configuration system implemented.
|
||||
latent_diffusion_weights=False,
|
||||
):
|
||||
self.iterations = iterations
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.steps = steps
|
||||
self.cfg_scale = cfg_scale
|
||||
self.weights = weights
|
||||
self.config = config
|
||||
self.sampler_name = sampler_name
|
||||
self.latent_channels = latent_channels
|
||||
self.downsampling_factor = downsampling_factor
|
||||
self.grid = grid
|
||||
self.ddim_eta = ddim_eta
|
||||
self.precision = precision
|
||||
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
||||
self.strength = strength
|
||||
self.seamless = seamless
|
||||
self.embedding_path = embedding_path
|
||||
self.device_type = device_type
|
||||
self.model = None # empty for now
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.latent_diffusion_weights = latent_diffusion_weights
|
||||
|
||||
if device_type == 'cuda' and not torch.cuda.is_available():
|
||||
device_type = choose_torch_device()
|
||||
print(">> cuda not available, using device", device_type)
|
||||
self.device = torch.device(device_type)
|
||||
|
||||
# for VRAM usage statistics
|
||||
device_type = choose_torch_device()
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
|
||||
|
||||
if seed is None:
|
||||
self.seed = self._new_seed()
|
||||
else:
|
||||
self.seed = seed
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
|
||||
Optional named arguments are the same as those passed to T2I and prompt2image()
|
||||
"""
|
||||
results = self.prompt2image(prompt, **kwargs)
|
||||
pngwriter = PngWriter(outdir)
|
||||
prefix = pngwriter.unique_prefix()
|
||||
outputs = []
|
||||
for image, seed in results:
|
||||
name = f'{prefix}.{seed}.png'
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image, f'{prompt} -S{seed}', name)
|
||||
outputs.append([path, seed])
|
||||
return outputs
|
||||
|
||||
def txt2img(self, prompt, **kwargs):
|
||||
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
||||
return self.prompt2png(prompt, outdir, **kwargs)
|
||||
|
||||
def img2img(self, prompt, **kwargs):
|
||||
outdir = kwargs.pop('outdir', 'outputs/img-samples')
|
||||
assert (
|
||||
'init_img' in kwargs
|
||||
), 'call to img2img() must include the init_img argument'
|
||||
return self.prompt2png(prompt, outdir, **kwargs)
|
||||
|
||||
def prompt2image(
|
||||
self,
|
||||
# these are common
|
||||
prompt,
|
||||
iterations = None,
|
||||
steps = None,
|
||||
seed = None,
|
||||
cfg_scale = None,
|
||||
ddim_eta = None,
|
||||
skip_normalize = False,
|
||||
image_callback = None,
|
||||
step_callback = None,
|
||||
width = None,
|
||||
height = None,
|
||||
seamless = False,
|
||||
# these are specific to img2img
|
||||
init_img = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
gfpgan_strength= 0,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
sampler_name = None,
|
||||
log_tokenization= False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
ldm.prompt2image() is the common entry point for txt2img() and img2img()
|
||||
It takes the following arguments:
|
||||
prompt // prompt string (no default)
|
||||
iterations // iterations (1); image count=iterations
|
||||
steps // refinement steps per iteration
|
||||
seed // seed for random number generator
|
||||
width // width of image, in multiples of 64 (512)
|
||||
height // height of image, in multiples of 64 (512)
|
||||
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
|
||||
seamless // whether the generated image should tile
|
||||
init_img // path to an initial image - its dimensions override width and height
|
||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
|
||||
variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
|
||||
|
||||
To use the step callback, define a function that receives two arguments:
|
||||
- Image GPU data
|
||||
- The step number
|
||||
|
||||
To use the image callback, define a function of method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
def process_image(image,seed):
|
||||
image.save(f{'images/seed.png'})
|
||||
|
||||
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
|
||||
to create the requested output directory, select a unique informative name for each image, and
|
||||
write the prompt into the PNG metadata.
|
||||
"""
|
||||
# TODO: convert this into a getattr() loop
|
||||
steps = steps or self.steps
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
seamless = seamless or self.seamless
|
||||
cfg_scale = cfg_scale or self.cfg_scale
|
||||
ddim_eta = ddim_eta or self.ddim_eta
|
||||
iterations = iterations or self.iterations
|
||||
strength = strength or self.strength
|
||||
self.log_tokenization = log_tokenization
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
model = (
|
||||
self.load_model()
|
||||
) # will instantiate the model or return it from cache
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert (
|
||||
0.0 <= strength <= 1.0
|
||||
), 'can only work with strength in [0.0, 1.0]'
|
||||
assert (
|
||||
0.0 <= variation_amount <= 1.0
|
||||
), '-v --variation_amount must be in [0.0, 1.0]'
|
||||
|
||||
if len(with_variations) > 0 or variation_amount > 0.0:
|
||||
assert seed is not None,\
|
||||
'seed must be specified when using with_variations'
|
||||
if variation_amount == 0.0:
|
||||
assert iterations == 1,\
|
||||
'when using --with_variations, multiple iterations are only possible when using --variation_amount'
|
||||
assert all(0 <= weight <= 1 for _, weight in with_variations),\
|
||||
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
|
||||
|
||||
seed = seed or self.seed
|
||||
width, height, _ = self._resolution_check(width, height, log=True)
|
||||
|
||||
# TODO: - Check if this is still necessary to run on M1 devices.
|
||||
# - Move code into ldm.dream.devices to live alongside other
|
||||
# special-hardware casing code.
|
||||
if self.precision == 'autocast' and torch.cuda.is_available():
|
||||
scope = autocast
|
||||
else:
|
||||
scope = nullcontext
|
||||
|
||||
if sampler_name and (sampler_name != self.sampler_name):
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
tic = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
results = list()
|
||||
|
||||
try:
|
||||
if init_img:
|
||||
assert os.path.exists(init_img), f'{init_img}: File not found'
|
||||
init_image = self._load_img(init_img, width, height, fit).to(self.device)
|
||||
with scope(self.device.type):
|
||||
init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
make_image = self._img2img(
|
||||
prompt,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
ddim_eta=ddim_eta,
|
||||
skip_normalize=skip_normalize,
|
||||
init_latent=init_latent,
|
||||
strength=strength,
|
||||
callback=step_callback,
|
||||
)
|
||||
else:
|
||||
init_latent = None
|
||||
make_image = self._txt2img(
|
||||
prompt,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
ddim_eta=ddim_eta,
|
||||
skip_normalize=skip_normalize,
|
||||
width=width,
|
||||
height=height,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
initial_noise = None
|
||||
if variation_amount > 0 or len(with_variations) > 0:
|
||||
# use fixed initial noise plus random noise per iteration
|
||||
seed_everything(seed)
|
||||
initial_noise = self._get_noise(init_latent,width,height)
|
||||
for v_seed, v_weight in with_variations:
|
||||
seed = v_seed
|
||||
seed_everything(seed)
|
||||
next_noise = self._get_noise(init_latent,width,height)
|
||||
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||
if variation_amount > 0:
|
||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||
|
||||
device_type = choose_autocast_device(self.device)
|
||||
with scope(device_type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if variation_amount > 0:
|
||||
seed_everything(seed)
|
||||
target_noise = self._get_noise(init_latent,width,height)
|
||||
x_T = self.slerp(variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
seed_everything(seed)
|
||||
if self.device.type == 'mps':
|
||||
x_T = self._get_noise(init_latent,width,height)
|
||||
# make_image will do the equivalent of get_noise itself
|
||||
image = make_image(x_T)
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed)
|
||||
seed = self._new_seed()
|
||||
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
for result in results:
|
||||
image, seed = result
|
||||
try:
|
||||
if upscale is not None:
|
||||
from ldm.gfpgan.gfpgan_tools import (
|
||||
real_esrgan_upscale,
|
||||
)
|
||||
if len(upscale) < 2:
|
||||
upscale.append(0.75)
|
||||
image = real_esrgan_upscale(
|
||||
image,
|
||||
upscale[1],
|
||||
int(upscale[0]),
|
||||
prompt,
|
||||
seed,
|
||||
)
|
||||
if gfpgan_strength > 0:
|
||||
from ldm.gfpgan.gfpgan_tools import _run_gfpgan
|
||||
|
||||
image = _run_gfpgan(
|
||||
image, gfpgan_strength, prompt, seed, 1
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
|
||||
)
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, upscaled=True)
|
||||
else: # no callback passed, so we simply replace old image with rescaled one
|
||||
result[0] = image
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('*interrupted*')
|
||||
print(
|
||||
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Are you sure your system has an adequate NVIDIA GPU?')
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
print(
|
||||
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
print(
|
||||
f'>> Max VRAM used for this generation:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
if self.session_peakmem:
|
||||
self.session_peakmem = max(
|
||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||
)
|
||||
print(
|
||||
f'>> Max VRAM used since script start: ',
|
||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||
)
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def _txt2img(
|
||||
self,
|
||||
prompt,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
skip_normalize,
|
||||
width,
|
||||
height,
|
||||
callback,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
|
||||
sampler = self.sampler
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
uc, c = self._get_uc_and_c(prompt, skip_normalize)
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
]
|
||||
samples, _ = sampler.sample(
|
||||
batch_size=1,
|
||||
S=steps,
|
||||
x_T=x_T,
|
||||
conditioning=c,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
img_callback=callback
|
||||
)
|
||||
return self._sample_to_image(samples)
|
||||
return make_image
|
||||
|
||||
@torch.no_grad()
|
||||
def _img2img(
|
||||
self,
|
||||
prompt,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
skip_normalize,
|
||||
init_latent,
|
||||
strength,
|
||||
callback, # Currently not implemented for img2img
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
|
||||
# PLMS sampler not supported yet, so ignore previous sampler
|
||||
if self.sampler_name != 'ddim':
|
||||
print(
|
||||
f">> sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
|
||||
)
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
else:
|
||||
sampler = self.sampler
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
uc, c = self._get_uc_and_c(prompt, skip_normalize)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc]).to(self.device),
|
||||
noise=x_T
|
||||
)
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback=callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
)
|
||||
return self._sample_to_image(samples)
|
||||
return make_image
|
||||
|
||||
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
|
||||
def _get_uc_and_c(self, prompt, skip_normalize):
|
||||
|
||||
uc = self.model.get_learned_conditioning([''])
|
||||
|
||||
# get weighted sub-prompts
|
||||
weighted_subprompts = T2I._split_weighted_subprompts(
|
||||
prompt, skip_normalize)
|
||||
|
||||
if len(weighted_subprompts) > 1:
|
||||
# i dont know if this is correct.. but it works
|
||||
c = torch.zeros_like(uc)
|
||||
# normalize each "sub prompt" and add it
|
||||
for subprompt, weight in weighted_subprompts:
|
||||
self._log_tokenization(subprompt)
|
||||
c = torch.add(
|
||||
c,
|
||||
self.model.get_learned_conditioning([subprompt]),
|
||||
alpha=weight,
|
||||
)
|
||||
else: # just standard 1 prompt
|
||||
self._log_tokenization(prompt)
|
||||
c = self.model.get_learned_conditioning([prompt])
|
||||
return (uc, c)
|
||||
|
||||
def _sample_to_image(self, samples):
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
if len(x_samples) != 1:
|
||||
raise Exception(
|
||||
f'>> expected to get a single image, but got {len(x_samples)}')
|
||||
x_sample = 255.0 * rearrange(
|
||||
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||
)
|
||||
return Image.fromarray(x_sample.astype(np.uint8))
|
||||
|
||||
def _new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(self.seed)
|
||||
try:
|
||||
config = OmegaConf.load(self.config)
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.full_precision
|
||||
)
|
||||
self.model = model.to(self.device)
|
||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||
self.model.cond_stage_model.device = self.device
|
||||
except AttributeError as e:
|
||||
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
raise SystemExit from e
|
||||
|
||||
self._set_sampler()
|
||||
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m._orig_padding_mode = m.padding_mode
|
||||
|
||||
return self.model
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def _get_noise(self,init_latent,width,height):
|
||||
if init_latent is not None:
|
||||
if self.device.type == 'mps':
|
||||
return torch.randn_like(init_latent, device='cpu').to(self.device)
|
||||
else:
|
||||
return torch.randn_like(init_latent, device=self.device)
|
||||
else:
|
||||
if self.device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(self.device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=self.device)
|
||||
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||
else:
|
||||
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
|
||||
print(msg)
|
||||
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
print(f'>> Loading model from {ckpt}')
|
||||
pl_sd = torch.load(ckpt, map_location='cpu')
|
||||
# if "global_step" in pl_sd:
|
||||
# print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
if self.full_precision:
|
||||
print(
|
||||
'Using slower but more accurate full-precision math (--full_precision)'
|
||||
)
|
||||
else:
|
||||
print(
|
||||
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
|
||||
)
|
||||
model.half()
|
||||
return model
|
||||
|
||||
def _load_img(self, path, width, height, fit=False):
|
||||
with Image.open(path) as img:
|
||||
image = img.convert('RGB')
|
||||
print(
|
||||
f'>> loaded input image of size {image.width}x{image.height} from {path}'
|
||||
)
|
||||
|
||||
# The logic here is:
|
||||
# 1. If "fit" is true, then the image will be fit into the bounding box defined
|
||||
# by width and height. It will do this in a way that preserves the init image's
|
||||
# aspect ratio while preventing letterboxing. This means that if there is
|
||||
# leftover horizontal space after rescaling the image to fit in the bounding box,
|
||||
# the generated image's width will be reduced to the rescaled init image's width.
|
||||
# Similarly for the vertical space.
|
||||
# 2. Otherwise, if "fit" is false, then the image will be scaled, preserving its
|
||||
# aspect ratio, to the nearest multiple of 64. Large images may generate an
|
||||
# unexpected OOM error.
|
||||
if fit:
|
||||
image = self._fit_image(image,(width,height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
def _squeeze_image(self,image):
|
||||
x,y,resize_needed = self._resolution_check(image.width,image.height)
|
||||
if resize_needed:
|
||||
return InitImageResizer(image).resize(x,y)
|
||||
return image
|
||||
|
||||
|
||||
def _fit_image(self,image,max_dimensions):
|
||||
w,h = max_dimensions
|
||||
print(
|
||||
f'>> image will be resized to fit inside a box {w}x{h} in size.'
|
||||
)
|
||||
if image.width > image.height:
|
||||
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
|
||||
elif image.height > image.width:
|
||||
w = None # ditto for w
|
||||
else:
|
||||
pass
|
||||
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
|
||||
print(
|
||||
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
# TO DO: Move this and related weighted subprompt code into its own module.
|
||||
def _split_weighted_subprompts(text, skip_normalize=False):
|
||||
"""
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / len(parsed_prompts)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def _log_tokenization(self, text):
|
||||
if not self.log_tokenization:
|
||||
return
|
||||
tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < self.model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
|
||||
|
||||
def _resolution_check(self, width, height, log=False):
|
||||
resize_needed = False
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (width, height)
|
||||
) # resize to integer multiple of 64
|
||||
if h != height or w != width:
|
||||
if log:
|
||||
print(
|
||||
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
||||
)
|
||||
height = h
|
||||
width = w
|
||||
resize_needed = True
|
||||
|
||||
if (width * height) > (self.width * self.height):
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
|
||||
return width, height, resize_needed
|
||||
|
||||
|
||||
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
'''
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
'''
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(self.device)
|
||||
|
||||
return v2
|
||||
class T2I(Generate):
|
||||
def __init__(self,**kwargs):
|
||||
print(f'>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. It is a drop-in replacement.')
|
||||
super().__init__(kwargs)
|
||||
|
@ -40,7 +40,7 @@ def main():
|
||||
print('* Initializing, be patient...\n')
|
||||
sys.path.append('.')
|
||||
from pytorch_lightning import logging
|
||||
from ldm.simplet2i import T2I
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@ -52,19 +52,18 @@ def main():
|
||||
# defaults passed on the command line.
|
||||
# additional parameters will be added (or overriden) during
|
||||
# the user input loop
|
||||
t2i = T2I(
|
||||
width=width,
|
||||
height=height,
|
||||
sampler_name=opt.sampler_name,
|
||||
weights=weights,
|
||||
full_precision=opt.full_precision,
|
||||
config=config,
|
||||
grid = opt.grid,
|
||||
t2i = Generate(
|
||||
width = width,
|
||||
height = height,
|
||||
sampler_name = opt.sampler_name,
|
||||
weights = weights,
|
||||
full_precision = opt.full_precision,
|
||||
config = config,
|
||||
grid = opt.grid,
|
||||
# this is solely for recreating the prompt
|
||||
latent_diffusion_weights=opt.laion400m,
|
||||
seamless=opt.seamless,
|
||||
embedding_path=opt.embedding_path,
|
||||
device_type=opt.device
|
||||
seamless = opt.seamless,
|
||||
embedding_path = opt.embedding_path,
|
||||
device_type = opt.device
|
||||
)
|
||||
|
||||
# make sure the output directory exists
|
||||
@ -567,6 +566,17 @@ def create_cmd_parser():
|
||||
type=str,
|
||||
help='Path to input image for img2img mode (supersedes width and height)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-M',
|
||||
'--mask',
|
||||
type=str,
|
||||
help='Path to inpainting mask; transparent areas will be painted over',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--invert_mask',
|
||||
action='store_true',
|
||||
help='Invert the inpainting mask; opaque areas will be painted over',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
|
Loading…
Reference in New Issue
Block a user