diff --git a/ldm/dream/conditioning.py b/ldm/dream/conditioning.py
new file mode 100644
index 0000000000..dfa108985a
--- /dev/null
+++ b/ldm/dream/conditioning.py
@@ -0,0 +1,96 @@
+'''
+This module handles the generation of the conditioning tensors, including management of
+weighted subprompts.
+
+Useful function exports:
+
+get_uc_and_c()               get the conditioned and unconditioned latent
+split_weighted_subprompts()  split subprompts, normalize and weight them
+log_tokenization()           print out colour-coded tokens and warn if truncated
+
+'''
+import re
+import torch
+
+def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+    uc = model.get_learned_conditioning([''])
+
+    # get weighted sub-prompts
+    weighted_subprompts = split_weighted_subprompts(
+        prompt, skip_normalize
+    )
+
+    if len(weighted_subprompts) > 1:
+        # I don't know if this is correct, but it works
+        c = torch.zeros_like(uc)
+        # normalize each "sub prompt" and add it
+        for subprompt, weight in weighted_subprompts:
+            log_tokenization(subprompt, model, log_tokens)
+            c = torch.add(
+                c,
+                model.get_learned_conditioning([subprompt]),
+                alpha=weight,
+            )
+    else:   # just standard 1 prompt
+        log_tokenization(prompt, model, log_tokens)
+        c = model.get_learned_conditioning([prompt])
+    return (uc, c)
+
+def split_weighted_subprompts(text, skip_normalize=False)->list:
+    """
+    grabs all text up to the first occurrence of ':'
+    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
+    if ':' has no value defined, defaults to 1.0
+    repeats until no text remaining
+    """
+    prompt_parser = re.compile("""
+            (?P<prompt>     # capture group for 'prompt'
+            (?:\\\:|[^:])+  # match one or more non ':' characters or escaped colons '\:'
+            )               # end 'prompt'
+            (?:             # non-capture group
+            :+              # match one or more ':' characters
+            (?P<weight>     # capture group for 'weight'
+            -?\d+(?:\.\d+)? # match positive or negative integer or decimal number
+            )?              # end weight capture group, make optional
+            \s*             # strip spaces after weight
+            |               # OR
+            $               # else, if no ':' then match end of line
+            )               # end non-capture group
+            """, re.VERBOSE)
+    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
+        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+    if skip_normalize:
+        return parsed_prompts
+    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+    if weight_sum == 0:
+        print(
+            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+        equal_weight = 1 / len(parsed_prompts)
+        return [(x[0], equal_weight) for x in parsed_prompts]
+    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+
+# shows how the prompt is tokenized
+# usually tokens have '</w>' to indicate end-of-word,
+# but for readability it has been replaced with ' '
+def log_tokenization(text, model, log=False):
+    if not log:
+        return
+    tokens = model.cond_stage_model.tokenizer._tokenize(text)
+    tokenized = ""
+    discarded = ""
+    usedTokens = 0
+    totalTokens = len(tokens)
+    for i in range(0, totalTokens):
+        token = tokens[i].replace('</w>', ' ')
+        # alternate color
+        s = (usedTokens % 6) + 1
+        if i < model.cond_stage_model.max_length:
+            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+            usedTokens += 1
+        else:  # over max token length
+            discarded = discarded + f"\x1b[0;3{s};40m{token}"
+    print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
+    if discarded != "":
+        print(
+            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
+        )
diff --git a/ldm/dream/devices.py b/ldm/dream/devices.py
index 7a205f6616..90bc9e97dd 100644
--- a/ldm/dream/devices.py
+++ b/ldm/dream/devices.py
@@ -1,4 +1,6 @@
 import torch
+from torch import autocast
+from contextlib import contextmanager, nullcontext
 
 def choose_torch_device() -> str:
     '''Convenience routine for guessing which GPU device to run model on'''
@@ -8,10 +10,11 @@ def choose_torch_device() -> str:
         return 'mps'
     return 'cpu'
 
-def choose_autocast_device(device) -> str:
+def choose_autocast_device(device):
     '''Returns an autocast compatible device from a torch device'''
     device_type = device.type # this returns 'mps' on M1
     # autocast only supports cuda or cpu
-    if device_type not in ('cuda','cpu'):
-        return 'cpu'
-    return device_type
+    if device_type in ('cuda','cpu'):
+        return device_type,autocast
+    else:
+        return 'cpu',nullcontext
diff --git a/ldm/dream/generator/__init__.py b/ldm/dream/generator/__init__.py
new file mode 100644
index 0000000000..b48e6e19c8
--- /dev/null
+++ b/ldm/dream/generator/__init__.py
@@ -0,0 +1,4 @@
+'''
+Initialization file for the ldm.dream.generator package
+'''
+from .base import Generator
diff --git a/ldm/dream/generator/base.py b/ldm/dream/generator/base.py
new file mode 100644
index 0000000000..9bed3df719
--- /dev/null
+++ b/ldm/dream/generator/base.py
@@ -0,0 +1,158 @@
+'''
+Base class for ldm.dream.generator.*
+including img2img, txt2img, and inpaint
+'''
+import torch
+import numpy as np
+import random
+from tqdm import tqdm, trange
+from PIL import Image
+from einops import rearrange, repeat
+from pytorch_lightning import seed_everything
+from ldm.dream.devices import choose_autocast_device
+
+downsampling = 8
+
+class Generator():
+    def __init__(self,model):
+        self.model = model
+        self.seed = None
+        self.latent_channels = model.channels
+        self.downsampling_factor = downsampling   # BUG: should come from model or config
+        self.variation_amount = 0
+        self.with_variations = []
+
+    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
+    def get_make_image(self,prompt,**kwargs):
+        """
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it
+        """
+        raise NotImplementedError("get_make_image() must be implemented in a descendent class")
+
+    def set_variation(self, seed, variation_amount, with_variations):
+        self.seed = seed
+        self.variation_amount = variation_amount
+        self.with_variations = with_variations
+
+    def 
generate(self,prompt,init_image,width,height,iterations=1,seed=None, + image_callback=None, step_callback=None, + **kwargs): + device_type,scope = choose_autocast_device(self.model.device) + make_image = self.get_make_image( + prompt, + init_image = init_image, + width = width, + height = height, + step_callback = step_callback, + **kwargs + ) + + results = [] + seed = seed if seed else self.new_seed() + seed, initial_noise = self.generate_initial_noise(seed, width, height) + with scope(device_type), self.model.ema_scope(): + for n in trange(iterations, desc='Generating'): + x_T = None + if self.variation_amount > 0: + seed_everything(seed) + target_noise = self.get_noise(width,height) + x_T = self.slerp(self.variation_amount, initial_noise, target_noise) + elif initial_noise is not None: + # i.e. we specified particular variations + x_T = initial_noise + else: + seed_everything(seed) + if self.model.device.type == 'mps': + x_T = self.get_noise(width,height) + + # make_image will do the equivalent of get_noise itself + image = make_image(x_T) + results.append([image, seed]) + if image_callback is not None: + image_callback(image, seed) + seed = self.new_seed() + return results + + def sample_to_image(self,samples): + """ + Returns a function returning an image derived from the prompt and the initial image + Return value depends on the seed at the time you call it + """ + x_samples = self.model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + if len(x_samples) != 1: + raise Exception( + f'>> expected to get a single image, but got {len(x_samples)}') + x_sample = 255.0 * rearrange( + x_samples[0].cpu().numpy(), 'c h w -> h w c' + ) + return Image.fromarray(x_sample.astype(np.uint8)) + + def generate_initial_noise(self, seed, width, height): + initial_noise = None + if self.variation_amount > 0 or len(self.with_variations) > 0: + # use fixed initial noise plus random noise per iteration + seed_everything(seed) + initial_noise = self.get_noise(width,height) + for v_seed, v_weight in self.with_variations: + seed = v_seed + seed_everything(seed) + next_noise = self.get_noise(width,height) + initial_noise = self.slerp(v_weight, initial_noise, next_noise) + if self.variation_amount > 0: + random.seed() # reset RNG to an actually random state, so we can get a random seed for variations + seed = random.randrange(0,np.iinfo(np.uint32).max) + return (seed, initial_noise) + else: + return (seed, None) + + # returns a tensor filled with random numbers from a normal distribution + def get_noise(self,width,height): + """ + Returns a tensor filled with random numbers, either form a normal distribution + (txt2img) or from the latent image (img2img, inpaint) + """ + raise NotImplementedError("get_noise() must be implemented in a descendent class") + + def new_seed(self): + self.seed = random.randrange(0, np.iinfo(np.uint32).max) + return self.seed + + def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): + ''' + Spherical linear interpolation + Args: + t (float/np.ndarray): Float value between 0.0 and 1.0 + v0 (np.ndarray): Starting vector + v1 (np.ndarray): Final vector + DOT_THRESHOLD (float): Threshold for considering the two vectors as + colineal. Not recommended to alter this. 
+ Returns: + v2 (np.ndarray): Interpolation vector between v0 and v1 + ''' + inputs_are_torch = False + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + v0 = v0.detach().cpu().numpy() + if not isinstance(v1, np.ndarray): + inputs_are_torch = True + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(self.model.device) + + return v2 + diff --git a/ldm/dream/generator/img2img.py b/ldm/dream/generator/img2img.py new file mode 100644 index 0000000000..242912d0eb --- /dev/null +++ b/ldm/dream/generator/img2img.py @@ -0,0 +1,72 @@ +''' +ldm.dream.generator.txt2img descends from ldm.dream.generator +''' + +import torch +import numpy as np +from ldm.dream.devices import choose_autocast_device +from ldm.dream.generator.base import Generator +from ldm.models.diffusion.ddim import DDIMSampler + +class Img2Img(Generator): + def __init__(self,model): + super().__init__(model) + self.init_latent = None # by get_noise() + + @torch.no_grad() + def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,init_image,strength,step_callback=None,**kwargs): + """ + Returns a function returning an image derived from the prompt and the initial image + Return value depends on the seed at the time you call it. + """ + + # PLMS sampler not supported yet, so ignore previous sampler + if not isinstance(sampler,DDIMSampler): + print( + f">> sampler '{sampler.__class__.__name__}' is not yet supported. 
Using DDIM sampler" + ) + sampler = DDIMSampler(self.model, device=self.model.device) + + sampler.make_schedule( + ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False + ) + + device_type,scope = choose_autocast_device(self.model.device) + with scope(device_type): + self.init_latent = self.model.get_first_stage_encoding( + self.model.encode_first_stage(init_image) + ) # move to latent space + + t_enc = int(strength * steps) + uc, c = conditioning + + @torch.no_grad() + def make_image(x_T): + # encode (scaled latent) + z_enc = sampler.stochastic_encode( + self.init_latent, + torch.tensor([t_enc]).to(self.model.device), + noise=x_T + ) + # decode it + samples = sampler.decode( + z_enc, + c, + t_enc, + img_callback = step_callback, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + ) + return self.sample_to_image(samples) + + return make_image + + def get_noise(self,width,height): + device = self.model.device + init_latent = self.init_latent + assert init_latent is not None,'call to get_noise() when init_latent not set' + if device.type == 'mps': + return torch.randn_like(init_latent, device='cpu').to(device) + else: + return torch.randn_like(init_latent, device=device) diff --git a/ldm/dream/generator/inpaint.py b/ldm/dream/generator/inpaint.py new file mode 100644 index 0000000000..d70e64121f --- /dev/null +++ b/ldm/dream/generator/inpaint.py @@ -0,0 +1,76 @@ +''' +ldm.dream.generator.inpaint descends from ldm.dream.generator +''' + +import torch +import numpy as np +from einops import rearrange, repeat +from ldm.dream.devices import choose_autocast_device +from ldm.dream.generator.img2img import Img2Img +from ldm.models.diffusion.ddim import DDIMSampler + +class Inpaint(Img2Img): + def __init__(self,model): + super().__init__(model) + + @torch.no_grad() + def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,init_image,init_mask,strength, + step_callback=None,**kwargs): + """ + Returns a function returning an image derived from the prompt and + the initial image + mask. Return value depends on the seed at + the time you call it. kwargs are 'init_latent' and 'strength' + """ + + init_mask = init_mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) + init_mask = repeat(init_mask, '1 ... -> b ...', b=1) + + # PLMS sampler not supported yet, so ignore previous sampler + if not isinstance(sampler,DDIMSampler): + print( + f">> sampler '{sampler.__class__.__name__}' is not yet supported. 
Using DDIM sampler" + ) + sampler = DDIMSampler(self.model, device=self.model.device) + + sampler.make_schedule( + ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False + ) + + device_type,scope = choose_autocast_device(self.model.device) + with scope(device_type): + self.init_latent = self.model.get_first_stage_encoding( + self.model.encode_first_stage(init_image) + ) # move to latent space + + t_enc = int(strength * steps) + uc, c = conditioning + + print(f">> target t_enc is {t_enc} steps") + + @torch.no_grad() + def make_image(x_T): + # encode (scaled latent) + z_enc = sampler.stochastic_encode( + self.init_latent, + torch.tensor([t_enc]).to(self.model.device), + noise=x_T + ) + + # decode it + samples = sampler.decode( + z_enc, + c, + t_enc, + img_callback = step_callback, + unconditional_guidance_scale = cfg_scale, + unconditional_conditioning = uc, + mask = init_mask, + init_latent = self.init_latent + ) + return self.sample_to_image(samples) + + return make_image + + + diff --git a/ldm/dream/generator/txt2img.py b/ldm/dream/generator/txt2img.py new file mode 100644 index 0000000000..d4cd25cb51 --- /dev/null +++ b/ldm/dream/generator/txt2img.py @@ -0,0 +1,61 @@ +''' +ldm.dream.generator.txt2img inherits from ldm.dream.generator +''' + +import torch +import numpy as np +from ldm.dream.generator.base import Generator + +class Txt2Img(Generator): + def __init__(self,model): + super().__init__(model) + + @torch.no_grad() + def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,width,height,step_callback=None,**kwargs): + """ + Returns a function returning an image derived from the prompt and the initial image + Return value depends on the seed at the time you call it + kwargs are 'width' and 'height' + """ + uc, c = conditioning + + @torch.no_grad() + def make_image(x_T): + shape = [ + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor, + ] + samples, _ = sampler.sample( + batch_size = 1, + S = steps, + x_T = x_T, + conditioning = c, + shape = shape, + verbose = False, + unconditional_guidance_scale = cfg_scale, + unconditional_conditioning = uc, + eta = ddim_eta, + img_callback = step_callback + ) + return self.sample_to_image(samples) + + return make_image + + + # returns a tensor filled with random numbers from a normal distribution + def get_noise(self,width,height): + device = self.model.device + if device.type == 'mps': + return torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device='cpu').to(device) + else: + return torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device=device) diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index b97cc1470c..4ba4812433 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -59,10 +59,16 @@ class PromptFormatter: switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') switches.append(f'-A{opt.sampler_name or t2i.sampler_name}') +# to do: put model name into the t2i object +# switches.append(f'--model{t2i.model_name}') + if opt.invert_mask: + switches.append(f'--invert_mask') if opt.seamless or t2i.seamless: switches.append(f'--seamless') if opt.init_img: switches.append(f'-I{opt.init_img}') + if opt.mask: + switches.append(f'-M{opt.mask}') if opt.fit: switches.append(f'--fit') if opt.strength and opt.init_img is not None: diff --git a/ldm/dream/readline.py 
b/ldm/dream/readline.py index 24a4493ad9..2aa8520acf 100644 --- a/ldm/dream/readline.py +++ b/ldm/dream/readline.py @@ -22,7 +22,7 @@ class Completer: def complete(self, text, state): buffer = readline.get_line_buffer() - if text.startswith(('-I', '--init_img')): + if text.startswith(('-I', '--init_img','-M','--init_mask')): return self._path_completions(text, state, ('.png','.jpg','.jpeg')) if buffer.strip().endswith('cd') or text.startswith(('.', '/')): @@ -48,10 +48,15 @@ class Completer: def _path_completions(self, text, state, extensions): # get the path so far + # TODO: replace this mess with a regular expression match if text.startswith('-I'): path = text.replace('-I', '', 1).lstrip() elif text.startswith('--init_img='): path = text.replace('--init_img=', '', 1).lstrip() + elif text.startswith('--init_mask='): + path = text.replace('--init_mask=', '', 1).lstrip() + elif text.startswith('-M'): + path = text.replace('-M', '', 1).lstrip() else: path = text @@ -94,6 +99,7 @@ if readline_available: '--grid','-g', '--individual','-i', '--init_img','-I', + '--init_mask','-M', '--strength','-f', '--variants','-v', '--outdir','-o', diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 46af9df931..10cd7d722e 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -144,7 +144,7 @@ class DreamServer(BaseHTTPRequestHandler): # and don't bother with the last one, since it'll render anyway nonlocal step_index if progress_images and step % 5 == 0 and step < steps - 1: - image = self.model._sample_to_image(sample) + image = self.model.sample_to_image(sample) name = f'{prefix}.{seed}.{step_index}.png' metadata = f'{prompt} -S{seed} [intermediate]' path = step_writer.save_image_and_prompt_to_png(image, metadata, name) diff --git a/ldm/generate.py b/ldm/generate.py new file mode 100644 index 0000000000..9ba72c3676 --- /dev/null +++ b/ldm/generate.py @@ -0,0 +1,642 @@ +# Copyright (c) 2022 Lincoln D. 
Stein (https://github.com/lstein) + +# Derived from source code carrying the following copyrights +# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich +# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +import torch +import numpy as np +import random +import os +import time +import re +import sys +import traceback +import transformers + +from omegaconf import OmegaConf +from PIL import Image, ImageOps +from torch import nn +from pytorch_lightning import seed_everything + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ksampler import KSampler +from ldm.dream.pngwriter import PngWriter +from ldm.dream.image_util import InitImageResizer +from ldm.dream.devices import choose_torch_device +from ldm.dream.conditioning import get_uc_and_c + +"""Simplified text to image API for stable diffusion/latent diffusion + +Example Usage: + +from ldm.generate import Generate + +# Create an object with default values +gr = Generate(model = // models/ldm/stable-diffusion-v1/model.ckpt + config = // configs/stable-diffusion/v1-inference.yaml + iterations = // how many times to run the sampling (1) + steps = // 50 + seed = // current system time + sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms + grid = // false + width = // image width, multiple of 64 (512) + height = // image height, multiple of 64 (512) + cfg_scale = // condition-free guidance scale (7.5) + ) + +# do the slow model initialization +gr.load_model() + +# Do the fast inference & image generation. Any options passed here +# override the default values assigned during class initialization +# Will call load_model() if the model was not previously loaded and so +# may be slow at first. +# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] +results = gr.prompt2png(prompt = "an astronaut riding a horse", + outdir = "./outputs/samples", + iterations = 3) + +for row in results: + print(f'filename={row[0]}') + print(f'seed ={row[1]}') + +# Same thing, but using an initial image. +results = gr.prompt2png(prompt = "an astronaut riding a horse", + outdir = "./outputs/, + iterations = 3, + init_img = "./sketches/horse+rider.png") + +for row in results: + print(f'filename={row[0]}') + print(f'seed ={row[1]}') + +# Same thing, but we return a series of Image objects, which lets you manipulate them, +# combine them, and save them under arbitrary names + +results = gr.prompt2image(prompt = "an astronaut riding a horse" + outdir = "./outputs/") +for row in results: + im = row[0] + seed = row[1] + im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png') + im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg') + +Note that the old txt2img() and img2img() calls are deprecated but will +still work. 
+""" + + +class Generate: + """Generate class + Stores default values for multiple configuration items + """ + + def __init__( + self, + iterations = 1, + steps = 50, + cfg_scale = 7.5, + weights = 'models/ldm/stable-diffusion-v1/model.ckpt', + config = 'configs/stable-diffusion/v1-inference.yaml', + grid = False, + width = 512, + height = 512, + sampler_name = 'k_lms', + ddim_eta = 0.0, # deterministic + precision = 'autocast', + full_precision = False, + strength = 0.75, # default in scripts/img2img.py + seamless = False, + embedding_path = None, + device_type = 'cuda', + ): + self.iterations = iterations + self.width = width + self.height = height + self.steps = steps + self.cfg_scale = cfg_scale + self.weights = weights + self.config = config + self.sampler_name = sampler_name + self.grid = grid + self.ddim_eta = ddim_eta + self.precision = precision + self.full_precision = True if choose_torch_device() == 'mps' else full_precision + self.strength = strength + self.seamless = seamless + self.embedding_path = embedding_path + self.device_type = device_type + self.model = None # empty for now + self.sampler = None + self.device = None + self.generators = {} + self.base_generator = None + self.seed = None + + if device_type == 'cuda' and not torch.cuda.is_available(): + device_type = choose_torch_device() + print(">> cuda not available, using device", device_type) + self.device = torch.device(device_type) + + # for VRAM usage statistics + device_type = choose_torch_device() + self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None + transformers.logging.set_verbosity_error() + + def prompt2png(self, prompt, outdir, **kwargs): + """ + Takes a prompt and an output directory, writes out the requested number + of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
+ Optional named arguments are the same as those passed to Generate and prompt2image() + """ + results = self.prompt2image(prompt, **kwargs) + pngwriter = PngWriter(outdir) + prefix = pngwriter.unique_prefix() + outputs = [] + for image, seed in results: + name = f'{prefix}.{seed}.png' + path = pngwriter.save_image_and_prompt_to_png( + image, f'{prompt} -S{seed}', name) + outputs.append([path, seed]) + return outputs + + def txt2img(self, prompt, **kwargs): + outdir = kwargs.pop('outdir', 'outputs/img-samples') + return self.prompt2png(prompt, outdir, **kwargs) + + def img2img(self, prompt, **kwargs): + outdir = kwargs.pop('outdir', 'outputs/img-samples') + assert ( + 'init_img' in kwargs + ), 'call to img2img() must include the init_img argument' + return self.prompt2png(prompt, outdir, **kwargs) + + def prompt2image( + self, + # these are common + prompt, + iterations = None, + steps = None, + seed = None, + cfg_scale = None, + ddim_eta = None, + skip_normalize = False, + image_callback = None, + step_callback = None, + width = None, + height = None, + sampler_name = None, + seamless = False, + log_tokenization= False, + with_variations = None, + variation_amount = 0.0, + # these are specific to img2img + init_img = None, + mask = None, + invert_mask = False, + fit = False, + strength = None, + # these are specific to GFPGAN/ESRGAN + gfpgan_strength= 0, + save_original = False, + upscale = None, + **args, + ): # eat up additional cruft + """ + ldm.prompt2image() is the common entry point for txt2img() and img2img() + It takes the following arguments: + prompt // prompt string (no default) + iterations // iterations (1); image count=iterations + steps // refinement steps per iteration + seed // seed for random number generator + width // width of image, in multiples of 64 (512) + height // height of image, in multiples of 64 (512) + cfg_scale // how strongly the prompt influences the image (7.5) (must be >1) + seamless // whether the generated image should tile + init_img // path to an initial image + mask // path to an initial image mask for inpainting + invert_mask // paint over opaque areas, retain transparent areas + strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely + gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely + ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) + step_callback // a function or method that will be called each step + image_callback // a function or method that will be called each time an image is generated + with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation + variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image) + + To use the step callback, define a function that receives two arguments: + - Image GPU data + - The step number + + To use the image callback, define a function of method that receives two arguments, an Image object + and the seed. You can then do whatever you like with the image, including converting it to + different formats and manipulating it. For example: + + def process_image(image,seed): + image.save(f{'images/seed.png'}) + + The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code + to create the requested output directory, select a unique informative name for each image, and + write the prompt into the PNG metadata. 
+ """ + # TODO: convert this into a getattr() loop + steps = steps or self.steps + width = width or self.width + height = height or self.height + seamless = seamless or self.seamless + cfg_scale = cfg_scale or self.cfg_scale + ddim_eta = ddim_eta or self.ddim_eta + iterations = iterations or self.iterations + strength = strength or self.strength + self.seed = seed + self.log_tokenization = log_tokenization + with_variations = [] if with_variations is None else with_variations + + model = ( + self.load_model() + ) # will instantiate the model or return it from cache + + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.padding_mode = 'circular' if seamless else m._orig_padding_mode + + assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' + assert ( + 0.0 < strength < 1.0 + ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0' + assert ( + 0.0 <= variation_amount <= 1.0 + ), '-v --variation_amount must be in [0.0, 1.0]' + + # check this logic - doesn't look right + if len(with_variations) > 0 or variation_amount > 1.0: + assert seed is not None,\ + 'seed must be specified when using with_variations' + if variation_amount == 0.0: + assert iterations == 1,\ + 'when using --with_variations, multiple iterations are only possible when using --variation_amount' + assert all(0 <= weight <= 1 for _, weight in with_variations),\ + f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' + + width, height, _ = self._resolution_check(width, height, log=True) + + if sampler_name and (sampler_name != self.sampler_name): + self.sampler_name = sampler_name + self._set_sampler() + + tic = time.time() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + results = list() + init_image = None + init_mask_image = None + + try: + uc, c = get_uc_and_c( + prompt, model=self.model, + skip_normalize=skip_normalize, + log_tokens=self.log_tokenization + ) + + if mask and not init_img: + raise AssertionError('If mask path is provided, initial image path should be provided as well') + + if mask and init_img: + init_image,size1 = self._load_img(init_img, width, height,fit=fit) + init_image.to(self.device) + init_mask_image,size2 = self._load_img_mask(mask, width, height,fit=fit, invert=invert_mask) + init_mask_image.to(self.device) + assert size1==size2,f"for inpainting, the initial image and its mask must be identical sizes, instead got {size1} vs {size2}" + generator = self._make_inpaint() + elif init_img: # little bit of repeated code here, but makes logic clearer + init_image,_ = self._load_img(init_img, width, height, fit=fit) + init_image.to(self.device) + generator = self._make_img2img() + else: + generator = self._make_txt2img() + + generator.set_variation(self.seed, variation_amount, with_variations) + results = generator.generate( + prompt, + iterations = iterations, + seed = self.seed, + sampler = self.sampler, + steps = steps, + cfg_scale = cfg_scale, + conditioning = (uc,c), + ddim_eta = ddim_eta, + image_callback = image_callback, # called after the final image is generated + step_callback = step_callback, # called after each intermediate image is generated + width = width, + height = height, + init_image = init_image, # notice that init_image is different from init_img + init_mask = init_mask_image, + strength = strength + ) + + if upscale is not None or gfpgan_strength > 0: + self.upscale_and_reconstruct(results, + upscale = upscale, + strength = gfpgan_strength, + save_original = save_original, + 
image_callback = image_callback) + + except KeyboardInterrupt: + print('*interrupted*') + print( + '>> Partial results will be returned; if --grid was requested, nothing will be returned.' + ) + except RuntimeError as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Are you sure your system has an adequate GPU?') + + toc = time.time() + print('>> Usage stats:') + print( + f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) + ) + print( + f'>> Max VRAM used for this generation:', + '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), + ) + + if self.session_peakmem: + self.session_peakmem = max( + self.session_peakmem, torch.cuda.max_memory_allocated() + ) + print( + f'>> Max VRAM used since script start: ', + '%4.2fG' % (self.session_peakmem / 1e9), + ) + return results + + def _make_img2img(self): + if not self.generators.get('img2img'): + from ldm.dream.generator.img2img import Img2Img + self.generators['img2img'] = Img2Img(self.model) + return self.generators['img2img'] + + def _make_txt2img(self): + if not self.generators.get('txt2img'): + from ldm.dream.generator.txt2img import Txt2Img + self.generators['txt2img'] = Txt2Img(self.model) + return self.generators['txt2img'] + + def _make_inpaint(self): + if not self.generators.get('inpaint'): + from ldm.dream.generator.inpaint import Inpaint + self.generators['inpaint'] = Inpaint(self.model) + return self.generators['inpaint'] + + def load_model(self): + """Load and initialize the model from configuration variables passed at object creation time""" + if self.model is None: + seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) + try: + config = OmegaConf.load(self.config) + model = self._load_model_from_config(config, self.weights) + if self.embedding_path is not None: + model.embedding_manager.load( + self.embedding_path, self.full_precision + ) + self.model = model.to(self.device) + # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here + self.model.cond_stage_model.device = self.device + except AttributeError as e: + print(f'>> Error loading model. {str(e)}', file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + raise SystemExit from e + + self._set_sampler() + + for m in self.model.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m._orig_padding_mode = m.padding_mode + + return self.model + + def upscale_and_reconstruct(self, + image_list, + upscale = None, + strength = 0.0, + save_original = False, + image_callback = None): + try: + if upscale is not None: + from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale + if strength > 0: + from ldm.gfpgan.gfpgan_tools import run_gfpgan + except (ModuleNotFoundError, ImportError): + print(traceback.format_exc(), file=sys.stderr) + print('>> You may need to install the ESRGAN and/or GFPGAN modules') + return + + for r in image_list: + image, seed = r + try: + if upscale is not None: + if len(upscale) < 2: + upscale.append(0.75) + image = real_esrgan_upscale( + image, + upscale[1], + int(upscale[0]), + seed, + ) + if strength > 0: + image = run_gfpgan( + image, strength, seed, 1 + ) + except Exception as e: + print( + f'>> Error running RealESRGAN or GFPGAN. 
Your image was not upscaled.\n{e}' + ) + + if image_callback is not None: + image_callback(image, seed, upscaled=True) + else: + r[0] = image + + # to help WebGUI - front end to generator util function + def sample_to_image(self,samples): + return self._sample_to_image(samples) + + def _sample_to_image(self,samples): + if not self.base_generator: + from ldm.dream.generator import Generator + self.base_generator = Generator(self.model) + return self.base_generator.sample_to_image(samples) + + def _set_sampler(self): + msg = f'>> Setting Sampler to {self.sampler_name}' + if self.sampler_name == 'plms': + self.sampler = PLMSSampler(self.model, device=self.device) + elif self.sampler_name == 'ddim': + self.sampler = DDIMSampler(self.model, device=self.device) + elif self.sampler_name == 'k_dpm_2_a': + self.sampler = KSampler( + self.model, 'dpm_2_ancestral', device=self.device + ) + elif self.sampler_name == 'k_dpm_2': + self.sampler = KSampler(self.model, 'dpm_2', device=self.device) + elif self.sampler_name == 'k_euler_a': + self.sampler = KSampler( + self.model, 'euler_ancestral', device=self.device + ) + elif self.sampler_name == 'k_euler': + self.sampler = KSampler(self.model, 'euler', device=self.device) + elif self.sampler_name == 'k_heun': + self.sampler = KSampler(self.model, 'heun', device=self.device) + elif self.sampler_name == 'k_lms': + self.sampler = KSampler(self.model, 'lms', device=self.device) + else: + msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' + self.sampler = PLMSSampler(self.model, device=self.device) + + print(msg) + + def _load_model_from_config(self, config, ckpt): + print(f'>> Loading model from {ckpt}') + pl_sd = torch.load(ckpt, map_location='cpu') + sd = pl_sd['state_dict'] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + model.to(self.device) + model.eval() + if self.full_precision: + print( + '>> Using slower but more accurate full-precision math (--full_precision)' + ) + else: + print( + '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.' + ) + model.half() + return model + + def _load_img(self, path, width, height, fit=False): + assert os.path.exists(path), f'>> {path}: File not found' + + with Image.open(path) as img: + image = img.convert('RGB') + print( + f'>> loaded input image of size {image.width}x{image.height} from {path}' + ) + if fit: + image = self._fit_image(image,(width,height)) + else: + image = self._squeeze_image(image) + + size = image.size + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2.0 * image - 1.0 + return image.to(self.device),size + + def _load_img_mask(self, path, width, height, fit=False, invert=False): + assert os.path.exists(path), f'>> {path}: File not found' + + image = Image.open(path) + print( + f'>> loaded input mask of size {image.width}x{image.height} from {path}' + ) + + if fit: + image = self._fit_image(image,(width,height)) + else: + image = self._squeeze_image(image) + + # convert into a black/white mask + image = self._mask_to_image(image,invert) + image = image.convert('RGB') + size = image.size + + # not quite sure what's going on here. 
It is copied from basunjindal's implementation + # image = image.resize((64, 64), resample=Image.Resampling.LANCZOS) + # BUG: We need to use the model's downsample factor rather than hardcoding "8" + from ldm.dream.generator.base import downsampling + image = image.resize((size[0]//downsampling, size[1]//downsampling), resample=Image.Resampling.LANCZOS) + image = np.array(image) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image.to(self.device),size + + # The mask is expected to have the region to be inpainted + # with alpha transparency. It converts it into a black/white + # image with the transparent part black. + def _mask_to_image(self, init_mask, invert=False) -> Image: + if self._has_transparency(init_mask): + # Obtain the mask from the transparency channel + mask = Image.new(mode="L", size=init_mask.size, color=255) + mask.putdata(init_mask.getdata(band=3)) + if invert: + mask = ImageOps.invert(mask) + return mask + else: + print(f'>> No transparent pixels in this image. Will paint across entire image.') + return Image.new(mode="L", size=mask.size, color=0) + + def _has_transparency(self,image): + if image.info.get("transparency", None) is not None: + return True + if image.mode == "P": + transparent = image.info.get("transparency", -1) + for _, index in image.getcolors(): + if index == transparent: + return True + elif image.mode == "RGBA": + extrema = image.getextrema() + if extrema[3][0] < 255: + return True + return False + + def _squeeze_image(self,image): + x,y,resize_needed = self._resolution_check(image.width,image.height) + if resize_needed: + return InitImageResizer(image).resize(x,y) + return image + + + def _fit_image(self,image,max_dimensions): + w,h = max_dimensions + print( + f'>> image will be resized to fit inside a box {w}x{h} in size.' + ) + if image.width > image.height: + h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height + elif image.height > image.width: + w = None # ditto for w + else: + pass + image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally + print( + f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' + ) + return image + + def _resolution_check(self, width, height, log=False): + resize_needed = False + w, h = map( + lambda x: x - x % 64, (width, height) + ) # resize to integer multiple of 64 + if h != height or w != width: + if log: + print( + f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' + ) + height = h + width = w + resize_needed = True + + if (width * height) > (self.width * self.height): + print(">> This input is larger than your defaults. 
If you run out of memory, please use a smaller image.") + + return width, height, resize_needed + + diff --git a/ldm/gfpgan/gfpgan_tools.py b/ldm/gfpgan/gfpgan_tools.py index 8fe8bb8d28..0de706ae42 100644 --- a/ldm/gfpgan/gfpgan_tools.py +++ b/ldm/gfpgan/gfpgan_tools.py @@ -13,8 +13,8 @@ opt = arg_parser.parse_args() model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) gfpgan_model_exists = os.path.isfile(model_path) -def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4): - print(f'>> GFPGAN - Restoring Faces: {prompt} : seed:{seed}') +def run_gfpgan(image, strength, seed, upsampler_scale=4): + print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') gfpgan = None with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) @@ -127,9 +127,9 @@ def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400): return bg_upsampler -def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed): +def real_esrgan_upscale(image, strength, upsampler_scale, seed): print( - f'>> Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x' + f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x' ) with warnings.catch_warnings(): diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 672ba24dd1..3868540526 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -171,6 +171,7 @@ class DDIMSampler(object): ) return samples, intermediates + # This routine gets called from img2img @torch.no_grad() def ddim_sampling( self, @@ -270,6 +271,7 @@ class DDIMSampler(object): return img, intermediates + # This routine gets called from ddim_sampling() and decode() @torch.no_grad() def p_sample_ddim( self, @@ -372,14 +374,16 @@ class DDIMSampler(object): @torch.no_grad() def decode( - self, - x_latent, - cond, - t_start, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, + self, + x_latent, + cond, + t_start, + img_callback=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + init_latent = None, + mask = None, ): timesteps = ( @@ -395,6 +399,8 @@ class DDIMSampler(object): iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent + x0 = init_latent + for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full( @@ -403,6 +409,14 @@ class DDIMSampler(object): device=x_latent.device, dtype=torch.long, ) + + if mask is not None: + assert x0 is not None + xdec_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? 
+ x_dec = xdec_orig * mask + (1.0 - mask) * x_dec + x_dec, _ = self.p_sample_ddim( x_dec, cond, @@ -412,6 +426,7 @@ class DDIMSampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) + if img_callback: img_callback(x_dec, i) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 817e9bcdc4..7d14ad0938 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -13,7 +13,7 @@ def exists(val): def uniq(arr): - return {el: True for el in arr}.keys() + return{el: True for el in arr}.keys() def default(val, d): @@ -45,18 +45,19 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -73,9 +74,7 @@ def zero_module(module): def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): @@ -83,28 +82,17 @@ class LinearAttention(nn.Module): super().__init__() self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, - 'b (qkv heads c) h w -> qkv b heads c (h w)', - heads=self.heads, - qkv=3, - ) - k = k.softmax(dim=-1) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange( - out, - 'b heads c (h w) -> b (heads c) h w', - heads=self.heads, - h=h, - w=w, - ) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) return self.to_out(out) @@ -114,18 +102,26 @@ class SpatialSelfAttention(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) def forward(self, x): h_ = x @@ -135,12 +131,12 @@ class SpatialSelfAttention(nn.Module): v 
= self.v(h_) # compute attention - b, c, h, w = q.shape + b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) - w_ = w_ * (int(c) ** (-0.5)) + w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values @@ -150,18 +146,16 @@ class SpatialSelfAttention(nn.Module): h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x + h_ + return x+h_ class CrossAttention(nn.Module): - def __init__( - self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0 - ): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head**-0.5 + self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -169,7 +163,8 @@ class CrossAttention(nn.Module): self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): @@ -179,59 +174,43 @@ class CrossAttention(nn.Module): context = default(context, x) k = self.to_k(context) v = self.to_v(context) + del context, x - q, k, v = map( - lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v) - ) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40) + del q, k if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) + del mask - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + # attention, what we cannot get enough of, by halves + sim[4:] = sim[4:].softmax(dim=-1) + sim[:4] = sim[:4].softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + sim = einsum('b i j, b j d -> b i d', sim, v) + sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - ): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) # is self-attn if context is none + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) + return checkpoint(self._forward, (x, context), 
self.parameters(), self.checkpoint) def _forward(self, x, context=None): x = x.contiguous() if x.device.type == 'mps' else x @@ -249,43 +228,29 @@ class SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - ): + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - n_heads, - d_head, - dropout=dropout, - context_dim=context_dim, - ) - for d in range(depth) - ] + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] ) - self.proj_out = zero_module( - nn.Conv2d( - inner_dim, in_channels, kernel_size=1, stride=1, padding=0 - ) - ) + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 808733c5fc..548c44fa49 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -1,856 +1,13 @@ -# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) +''' +This module is provided for backward compatibility with the +original (hasty) API. -# Derived from source code carrying the following copyrights -# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich -# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors +Please use ldm.generate instead. 
+''' -import torch -import numpy as np -import random -import os -import traceback -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torch import nn -from torchvision.utils import make_grid -from pytorch_lightning import seed_everything -from torch import autocast -from contextlib import contextmanager, nullcontext -import transformers -import time -import re -import sys +from ldm.generate import Generate -from ldm.util import instantiate_from_config -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.ksampler import KSampler -from ldm.dream.pngwriter import PngWriter -from ldm.dream.image_util import InitImageResizer -from ldm.dream.devices import choose_autocast_device, choose_torch_device - -"""Simplified text to image API for stable diffusion/latent diffusion - -Example Usage: - -from ldm.simplet2i import T2I - -# Create an object with default values -t2i = T2I(model = // models/ldm/stable-diffusion-v1/model.ckpt - config = // configs/stable-diffusion/v1-inference.yaml - iterations = // how many times to run the sampling (1) - steps = // 50 - seed = // current system time - sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms - grid = // false - width = // image width, multiple of 64 (512) - height = // image height, multiple of 64 (512) - cfg_scale = // unconditional guidance scale (7.5) - ) - -# do the slow model initialization -t2i.load_model() - -# Do the fast inference & image generation. Any options passed here -# override the default values assigned during class initialization -# Will call load_model() if the model was not previously loaded and so -# may be slow at first. -# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] -results = t2i.prompt2png(prompt = "an astronaut riding a horse", - outdir = "./outputs/samples", - iterations = 3) - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') - -# Same thing, but using an initial image. -results = t2i.prompt2png(prompt = "an astronaut riding a horse", - outdir = "./outputs/, - iterations = 3, - init_img = "./sketches/horse+rider.png") - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') - -# Same thing, but we return a series of Image objects, which lets you manipulate them, -# combine them, and save them under arbitrary names - -results = t2i.prompt2image(prompt = "an astronaut riding a horse" - outdir = "./outputs/") -for row in results: - im = row[0] - seed = row[1] - im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png') - im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg') - -Note that the old txt2img() and img2img() calls are deprecated but will -still work. -""" - - -class T2I: - """T2I class - Attributes - ---------- - model - config - iterations - steps - seed - sampler_name - width - height - cfg_scale - latent_channels - downsampling_factor - precision - strength - seamless - embedding_path - - The vast majority of these arguments default to reasonable values. 
-""" - - def __init__( - self, - iterations=1, - steps=50, - seed=None, - cfg_scale=7.5, - weights='models/ldm/stable-diffusion-v1/model.ckpt', - config='configs/stable-diffusion/v1-inference.yaml', - grid=False, - width=512, - height=512, - sampler_name='k_lms', - latent_channels=4, - downsampling_factor=8, - ddim_eta=0.0, # deterministic - precision='autocast', - full_precision=False, - strength=0.75, # default in scripts/img2img.py - seamless=False, - embedding_path=None, - device_type = 'cuda', - # just to keep track of this parameter when regenerating prompt - # needs to be replaced when new configuration system implemented. - latent_diffusion_weights=False, - ): - self.iterations = iterations - self.width = width - self.height = height - self.steps = steps - self.cfg_scale = cfg_scale - self.weights = weights - self.config = config - self.sampler_name = sampler_name - self.latent_channels = latent_channels - self.downsampling_factor = downsampling_factor - self.grid = grid - self.ddim_eta = ddim_eta - self.precision = precision - self.full_precision = True if choose_torch_device() == 'mps' else full_precision - self.strength = strength - self.seamless = seamless - self.embedding_path = embedding_path - self.device_type = device_type - self.model = None # empty for now - self.sampler = None - self.device = None - self.latent_diffusion_weights = latent_diffusion_weights - - if device_type == 'cuda' and not torch.cuda.is_available(): - device_type = choose_torch_device() - print(">> cuda not available, using device", device_type) - self.device = torch.device(device_type) - - # for VRAM usage statistics - device_type = choose_torch_device() - self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None - - if seed is None: - self.seed = self._new_seed() - else: - self.seed = seed - transformers.logging.set_verbosity_error() - - def prompt2png(self, prompt, outdir, **kwargs): - """ - Takes a prompt and an output directory, writes out the requested number - of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
- Optional named arguments are the same as those passed to T2I and prompt2image() - """ - results = self.prompt2image(prompt, **kwargs) - pngwriter = PngWriter(outdir) - prefix = pngwriter.unique_prefix() - outputs = [] - for image, seed in results: - name = f'{prefix}.{seed}.png' - path = pngwriter.save_image_and_prompt_to_png( - image, f'{prompt} -S{seed}', name) - outputs.append([path, seed]) - return outputs - - def txt2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', 'outputs/img-samples') - return self.prompt2png(prompt, outdir, **kwargs) - - def img2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', 'outputs/img-samples') - assert ( - 'init_img' in kwargs - ), 'call to img2img() must include the init_img argument' - return self.prompt2png(prompt, outdir, **kwargs) - - def prompt2image( - self, - # these are common - prompt, - iterations = None, - steps = None, - seed = None, - cfg_scale = None, - ddim_eta = None, - skip_normalize = False, - image_callback = None, - step_callback = None, - width = None, - height = None, - seamless = False, - # these are specific to img2img - init_img = None, - fit = False, - strength = None, - gfpgan_strength= 0, - save_original = False, - upscale = None, - sampler_name = None, - log_tokenization= False, - with_variations = None, - variation_amount = 0.0, - **args, - ): # eat up additional cruft - """ - ldm.prompt2image() is the common entry point for txt2img() and img2img() - It takes the following arguments: - prompt // prompt string (no default) - iterations // iterations (1); image count=iterations - steps // refinement steps per iteration - seed // seed for random number generator - width // width of image, in multiples of 64 (512) - height // height of image, in multiples of 64 (512) - cfg_scale // how strongly the prompt influences the image (7.5) (must be >1) - seamless // whether the generated image should tile - init_img // path to an initial image - its dimensions override width and height - strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely - gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely - ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) - step_callback // a function or method that will be called each step - image_callback // a function or method that will be called each time an image is generated - with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation - variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image) - - To use the step callback, define a function that receives two arguments: - - Image GPU data - - The step number - - To use the image callback, define a function of method that receives two arguments, an Image object - and the seed. You can then do whatever you like with the image, including converting it to - different formats and manipulating it. For example: - - def process_image(image,seed): - image.save(f{'images/seed.png'}) - - The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code - to create the requested output directory, select a unique informative name for each image, and - write the prompt into the PNG metadata. 
- """ - # TODO: convert this into a getattr() loop - steps = steps or self.steps - width = width or self.width - height = height or self.height - seamless = seamless or self.seamless - cfg_scale = cfg_scale or self.cfg_scale - ddim_eta = ddim_eta or self.ddim_eta - iterations = iterations or self.iterations - strength = strength or self.strength - self.log_tokenization = log_tokenization - with_variations = [] if with_variations is None else with_variations - - model = ( - self.load_model() - ) # will instantiate the model or return it from cache - for m in model.modules(): - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m.padding_mode = 'circular' if seamless else m._orig_padding_mode - - assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' - assert ( - 0.0 <= strength <= 1.0 - ), 'can only work with strength in [0.0, 1.0]' - assert ( - 0.0 <= variation_amount <= 1.0 - ), '-v --variation_amount must be in [0.0, 1.0]' - - if len(with_variations) > 0 or variation_amount > 0.0: - assert seed is not None,\ - 'seed must be specified when using with_variations' - if variation_amount == 0.0: - assert iterations == 1,\ - 'when using --with_variations, multiple iterations are only possible when using --variation_amount' - assert all(0 <= weight <= 1 for _, weight in with_variations),\ - f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' - - seed = seed or self.seed - width, height, _ = self._resolution_check(width, height, log=True) - - # TODO: - Check if this is still necessary to run on M1 devices. - # - Move code into ldm.dream.devices to live alongside other - # special-hardware casing code. - if self.precision == 'autocast' and torch.cuda.is_available(): - scope = autocast - else: - scope = nullcontext - - if sampler_name and (sampler_name != self.sampler_name): - self.sampler_name = sampler_name - self._set_sampler() - - tic = time.time() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - results = list() - - try: - if init_img: - assert os.path.exists(init_img), f'{init_img}: File not found' - init_image = self._load_img(init_img, width, height, fit).to(self.device) - with scope(self.device.type): - init_latent = self.model.get_first_stage_encoding( - self.model.encode_first_stage(init_image) - ) # move to latent space - - make_image = self._img2img( - prompt, - steps=steps, - cfg_scale=cfg_scale, - ddim_eta=ddim_eta, - skip_normalize=skip_normalize, - init_latent=init_latent, - strength=strength, - callback=step_callback, - ) - else: - init_latent = None - make_image = self._txt2img( - prompt, - steps=steps, - cfg_scale=cfg_scale, - ddim_eta=ddim_eta, - skip_normalize=skip_normalize, - width=width, - height=height, - callback=step_callback, - ) - - initial_noise = None - if variation_amount > 0 or len(with_variations) > 0: - # use fixed initial noise plus random noise per iteration - seed_everything(seed) - initial_noise = self._get_noise(init_latent,width,height) - for v_seed, v_weight in with_variations: - seed = v_seed - seed_everything(seed) - next_noise = self._get_noise(init_latent,width,height) - initial_noise = self.slerp(v_weight, initial_noise, next_noise) - if variation_amount > 0: - random.seed() # reset RNG to an actually random state, so we can get a random seed for variations - seed = random.randrange(0,np.iinfo(np.uint32).max) - - device_type = choose_autocast_device(self.device) - with scope(device_type), self.model.ema_scope(): - for n in trange(iterations, desc='Generating'): - x_T = None - if 
variation_amount > 0: - seed_everything(seed) - target_noise = self._get_noise(init_latent,width,height) - x_T = self.slerp(variation_amount, initial_noise, target_noise) - elif initial_noise is not None: - # i.e. we specified particular variations - x_T = initial_noise - else: - seed_everything(seed) - if self.device.type == 'mps': - x_T = self._get_noise(init_latent,width,height) - # make_image will do the equivalent of get_noise itself - image = make_image(x_T) - results.append([image, seed]) - if image_callback is not None: - image_callback(image, seed) - seed = self._new_seed() - - if upscale is not None or gfpgan_strength > 0: - for result in results: - image, seed = result - try: - if upscale is not None: - from ldm.gfpgan.gfpgan_tools import ( - real_esrgan_upscale, - ) - if len(upscale) < 2: - upscale.append(0.75) - image = real_esrgan_upscale( - image, - upscale[1], - int(upscale[0]), - prompt, - seed, - ) - if gfpgan_strength > 0: - from ldm.gfpgan.gfpgan_tools import _run_gfpgan - - image = _run_gfpgan( - image, gfpgan_strength, prompt, seed, 1 - ) - except Exception as e: - print( - f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}' - ) - if image_callback is not None: - image_callback(image, seed, upscaled=True) - else: # no callback passed, so we simply replace old image with rescaled one - result[0] = image - - except KeyboardInterrupt: - print('*interrupted*') - print( - '>> Partial results will be returned; if --grid was requested, nothing will be returned.' - ) - except RuntimeError as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Are you sure your system has an adequate NVIDIA GPU?') - - toc = time.time() - print('>> Usage stats:') - print( - f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) - ) - print( - f'>> Max VRAM used for this generation:', - '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), - ) - - if self.session_peakmem: - self.session_peakmem = max( - self.session_peakmem, torch.cuda.max_memory_allocated() - ) - print( - f'>> Max VRAM used since script start: ', - '%4.2fG' % (self.session_peakmem / 1e9), - ) - return results - - @torch.no_grad() - def _txt2img( - self, - prompt, - steps, - cfg_scale, - ddim_eta, - skip_normalize, - width, - height, - callback, - ): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - """ - - sampler = self.sampler - - @torch.no_grad() - def make_image(x_T): - uc, c = self._get_uc_and_c(prompt, skip_normalize) - shape = [ - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor, - ] - samples, _ = sampler.sample( - batch_size=1, - S=steps, - x_T=x_T, - conditioning=c, - shape=shape, - verbose=False, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - img_callback=callback - ) - return self._sample_to_image(samples) - return make_image - - @torch.no_grad() - def _img2img( - self, - prompt, - steps, - cfg_scale, - ddim_eta, - skip_normalize, - init_latent, - strength, - callback, # Currently not implemented for img2img - ): - """ - Returns a function returning an image derived from the prompt and the initial image - Return value depends on the seed at the time you call it - """ - - # PLMS sampler not supported yet, so ignore previous sampler - if self.sampler_name != 'ddim': - print( - f">> sampler '{self.sampler_name}' is not yet supported. 
Using DDIM sampler" - ) - sampler = DDIMSampler(self.model, device=self.device) - else: - sampler = self.sampler - - sampler.make_schedule( - ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False - ) - - t_enc = int(strength * steps) - - @torch.no_grad() - def make_image(x_T): - uc, c = self._get_uc_and_c(prompt, skip_normalize) - - # encode (scaled latent) - z_enc = sampler.stochastic_encode( - init_latent, - torch.tensor([t_enc]).to(self.device), - noise=x_T - ) - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback=callback, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - ) - return self._sample_to_image(samples) - return make_image - - # TODO: does this actually need to run every loop? does anything in it vary by random seed? - def _get_uc_and_c(self, prompt, skip_normalize): - - uc = self.model.get_learned_conditioning(['']) - - # get weighted sub-prompts - weighted_subprompts = T2I._split_weighted_subprompts( - prompt, skip_normalize) - - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - self._log_tokenization(subprompt) - c = torch.add( - c, - self.model.get_learned_conditioning([subprompt]), - alpha=weight, - ) - else: # just standard 1 prompt - self._log_tokenization(prompt) - c = self.model.get_learned_conditioning([prompt]) - return (uc, c) - - def _sample_to_image(self, samples): - x_samples = self.model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - if len(x_samples) != 1: - raise Exception( - f'>> expected to get a single image, but got {len(x_samples)}') - x_sample = 255.0 * rearrange( - x_samples[0].cpu().numpy(), 'c h w -> h w c' - ) - return Image.fromarray(x_sample.astype(np.uint8)) - - def _new_seed(self): - self.seed = random.randrange(0, np.iinfo(np.uint32).max) - return self.seed - - def load_model(self): - """Load and initialize the model from configuration variables passed at object creation time""" - if self.model is None: - seed_everything(self.seed) - try: - config = OmegaConf.load(self.config) - model = self._load_model_from_config(config, self.weights) - if self.embedding_path is not None: - model.embedding_manager.load( - self.embedding_path, self.full_precision - ) - self.model = model.to(self.device) - # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here - self.model.cond_stage_model.device = self.device - except AttributeError as e: - print(f'>> Error loading model. 
{str(e)}', file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - raise SystemExit from e - - self._set_sampler() - - for m in self.model.modules(): - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m._orig_padding_mode = m.padding_mode - - return self.model - - # returns a tensor filled with random numbers from a normal distribution - def _get_noise(self,init_latent,width,height): - if init_latent is not None: - if self.device.type == 'mps': - return torch.randn_like(init_latent, device='cpu').to(self.device) - else: - return torch.randn_like(init_latent, device=self.device) - else: - if self.device.type == 'mps': - return torch.randn([1, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - device='cpu').to(self.device) - else: - return torch.randn([1, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - device=self.device) - - def _set_sampler(self): - msg = f'>> Setting Sampler to {self.sampler_name}' - if self.sampler_name == 'plms': - self.sampler = PLMSSampler(self.model, device=self.device) - elif self.sampler_name == 'ddim': - self.sampler = DDIMSampler(self.model, device=self.device) - elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler( - self.model, 'dpm_2_ancestral', device=self.device - ) - elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler(self.model, 'dpm_2', device=self.device) - elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler( - self.model, 'euler_ancestral', device=self.device - ) - elif self.sampler_name == 'k_euler': - self.sampler = KSampler(self.model, 'euler', device=self.device) - elif self.sampler_name == 'k_heun': - self.sampler = KSampler(self.model, 'heun', device=self.device) - elif self.sampler_name == 'k_lms': - self.sampler = KSampler(self.model, 'lms', device=self.device) - else: - msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' - self.sampler = PLMSSampler(self.model, device=self.device) - - print(msg) - - def _load_model_from_config(self, config, ckpt): - print(f'>> Loading model from {ckpt}') - pl_sd = torch.load(ckpt, map_location='cpu') - # if "global_step" in pl_sd: - # print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd['state_dict'] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - model.to(self.device) - model.eval() - if self.full_precision: - print( - 'Using slower but more accurate full-precision math (--full_precision)' - ) - else: - print( - '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.' - ) - model.half() - return model - - def _load_img(self, path, width, height, fit=False): - with Image.open(path) as img: - image = img.convert('RGB') - print( - f'>> loaded input image of size {image.width}x{image.height} from {path}' - ) - - # The logic here is: - # 1. If "fit" is true, then the image will be fit into the bounding box defined - # by width and height. It will do this in a way that preserves the init image's - # aspect ratio while preventing letterboxing. This means that if there is - # leftover horizontal space after rescaling the image to fit in the bounding box, - # the generated image's width will be reduced to the rescaled init image's width. - # Similarly for the vertical space. - # 2. Otherwise, if "fit" is false, then the image will be scaled, preserving its - # aspect ratio, to the nearest multiple of 64. 
Large images may generate an - unexpected OOM error. - if fit: - image = self._fit_image(image,(width,height)) - else: - image = self._squeeze_image(image) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - def _squeeze_image(self,image): - x,y,resize_needed = self._resolution_check(image.width,image.height) - if resize_needed: - return InitImageResizer(image).resize(x,y) - return image - - - def _fit_image(self,image,max_dimensions): - w,h = max_dimensions - print( - f'>> image will be resized to fit inside a box {w}x{h} in size.' - ) - if image.width > image.height: - h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height - elif image.height > image.width: - w = None # ditto for w - else: - pass - image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally - print( - f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' - ) - return image - - - # TO DO: Move this and related weighted subprompt code into its own module. - def _split_weighted_subprompts(text, skip_normalize=False): - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - prompt_parser = re.compile(""" - (?P<prompt> # capture group for 'prompt' - (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P<weight> # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number - )? # end weight capture group, make optional - \s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group - """, re.VERBOSE) - parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( - match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] - if skip_normalize: - return parsed_prompts - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - if weight_sum == 0: - print( - "Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") - equal_weight = 1 / len(parsed_prompts) - return [(x[0], equal_weight) for x in parsed_prompts] - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - - # shows how the prompt is tokenized - # usually tokens have '</w>' to indicate end-of-word, - # but for readability it has been replaced with ' ' - def _log_tokenization(self, text): - if not self.log_tokenization: - return - tokens = self.model.cond_stage_model.tokenizer._tokenize(text) - tokenized = "" - discarded = "" - usedTokens = 0 - totalTokens = len(tokens) - for i in range(0, totalTokens): - token = tokens[i].replace('</w>', ' ') - # alternate color - s = (usedTokens % 6) + 1 - if i < self.model.cond_stage_model.max_length: - tokenized = tokenized + f"\x1b[0;3{s};40m{token}" - usedTokens += 1 - else: # over max token length - discarded = discarded + f"\x1b[0;3{s};40m{token}" - print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m") - if discarded != "": - print( - f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m") - - def _resolution_check(self, width, height, log=False): - resize_needed = False - w, h = map( - lambda x: x - x % 64, (width, height) - ) # resize to integer multiple of 64 - if h != height or w != width: - if log: - print( - f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' - ) - height = h - width = w - resize_needed = True - - if (width * height) > (self.width * self.height): - print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") - - return width, height, resize_needed - - - def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): - ''' - Spherical linear interpolation - Args: - t (float/np.ndarray): Float value between 0.0 and 1.0 - v0 (np.ndarray): Starting vector - v1 (np.ndarray): Final vector - DOT_THRESHOLD (float): Threshold for considering the two vectors as - colineal. Not recommended to alter this. - Returns: - v2 (np.ndarray): Interpolation vector between v0 and v1 - ''' - inputs_are_torch = False - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - v0 = v0.detach().cpu().numpy() - if not isinstance(v1, np.ndarray): - inputs_are_torch = True - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(self.device) - - return v2 +class T2I(Generate): + def __init__(self,**kwargs): + print(f'>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. It is a drop-in replacement.') + super().__init__(**kwargs) diff --git a/scripts/dream.py b/scripts/dream.py index 6c4a110c4e..11ab809890 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -40,7 +40,7 @@ def main(): print('* Initializing, be patient...\n') sys.path.append('.') from pytorch_lightning import logging - from ldm.simplet2i import T2I + from ldm.generate import Generate # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported @@ -52,19 +52,18 @@ def main(): # defaults passed on the command line.
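Below is a minimal usage sketch (not part of the patch) of the new ldm.generate.Generate class that replaces T2I in this hunk. It assumes Generate keeps T2I's load_model()/prompt2png() interface, as the deprecation shim's "drop-in replacement" notice implies, and reuses constructor arguments that dream.py passes above; exact defaults live in ldm/generate.py.

    from ldm.generate import Generate

    gr = Generate(width=512, height=512, sampler_name='k_lms', full_precision=False)
    gr.load_model()  # slow, one-time weight loading
    results = gr.prompt2png(prompt='an astronaut riding a horse',
                            outdir='outputs/img-samples',
                            iterations=1)
    for filename, seed in results:  # prompt2png returns [filename, seed] pairs
        print(f'{filename} (seed {seed})')
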
# additional parameters will be added (or overriden) during # the user input loop - t2i = T2I( - width=width, - height=height, - sampler_name=opt.sampler_name, - weights=weights, - full_precision=opt.full_precision, - config=config, - grid = opt.grid, + t2i = Generate( + width = width, + height = height, + sampler_name = opt.sampler_name, + weights = weights, + full_precision = opt.full_precision, + config = config, + grid = opt.grid, # this is solely for recreating the prompt - latent_diffusion_weights=opt.laion400m, - seamless=opt.seamless, - embedding_path=opt.embedding_path, - device_type=opt.device + seamless = opt.seamless, + embedding_path = opt.embedding_path, + device_type = opt.device ) # make sure the output directory exists @@ -567,6 +566,17 @@ def create_cmd_parser(): type=str, help='Path to input image for img2img mode (supersedes width and height)', ) + parser.add_argument( + '-M', + '--mask', + type=str, + help='Path to inpainting mask; transparent areas will be painted over', + ) + parser.add_argument( + '--invert_mask', + action='store_true', + help='Invert the inpainting mask; opaque areas will be painted over', + ) parser.add_argument( '-T', '-fit',