2022-08-24 13:22:04 +00:00
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
2022-08-25 04:42:37 +00:00
import torch
import numpy as np
import random
import sys
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm , trange
from itertools import islice
from einops import rearrange , repeat
from torchvision . utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager , nullcontext
import transformers
import time
import re
import traceback
from ldm . util import instantiate_from_config
from ldm . models . diffusion . ddim import DDIMSampler
from ldm . models . diffusion . plms import PLMSSampler
from ldm . models . diffusion . ksampler import KSampler
from ldm . dream_util import PngWriter
2022-08-24 13:22:04 +00:00
2022-08-17 01:34:37 +00:00
""" Simplified text to image API for stable diffusion/latent diffusion
Example Usage:

from ldm.simplet2i import T2I

# Create an object with default values
t2i = T2I(weights      = <path>       // models/ldm/stable-diffusion-v1/model.ckpt
          config       = <path>       // configs/stable-diffusion/v1-inference.yaml
          iterations   = <integer>    // how many times to run the sampling (1)
          batch_size   = <integer>    // how many images to generate per sampling (1)
          steps        = <integer>    // 50
          seed         = <integer>    // a random seed is drawn when omitted
          sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
          width        = <integer>    // image width, multiple of 64 (512)
          height       = <integer>    // image height, multiple of 64 (512)
          cfg_scale    = <float>      // unconditional guidance scale (7.5)
          )

# do the slow model initialization
t2i.load_model()

# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization.
# Will call load_model() if the model was not previously loaded and so
# may be slow at first.
# The method returns a list of rows; each row is a sub-list of [filename, seed].
results = t2i.prompt2png(prompt     = "an astronaut riding a horse",
                         outdir     = "./outputs/samples",
                         iterations = 3)

for row in results:
    print(f'filename={row[0]}')
    print(f'seed    ={row[1]}')

# Same thing, but using an initial image.
results = t2i.prompt2png(prompt     = "an astronaut riding a horse",
                         outdir     = "./outputs/",
                         iterations = 3,
                         init_img   = "./sketches/horse+rider.png")

for row in results:
    print(f'filename={row[0]}')
    print(f'seed    ={row[1]}')

# Same thing, but we return a series of Image objects, which lets you manipulate them,
# combine them, and save them under arbitrary names
results = t2i.prompt2image(prompt = "an astronaut riding a horse")
for row in results:
    im   = row[0]
    seed = row[1]
    im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
    im.thumbnail((100, 100))   # resizes in place and returns None
    im.save('./outputs/samples/astronaut_thumb.jpg')

Note that the old txt2img() and img2img() calls are deprecated but will
still work.
"""
2022-08-17 01:34:37 +00:00
class T2I:
    """T2I class

    A simplified text-to-image / image-to-image front end for Stable Diffusion.
    Construct it once with default generation settings, then call prompt2image()
    (returns PIL images) or prompt2png() (writes PNG files); any setting given to
    the constructor can be overridden per call.

    Attributes
    ----------
    batch_size                images generated per sampling pass (1)
    iterations                sampling passes per call (1)
    steps                     refinement steps per pass (50)
    seed                      RNG seed; a random one is drawn when None is given
    cfg_scale                 classifier-free guidance scale, must be > 1.0 (7.5)
    weights                   path to the model checkpoint
    config                    path to the OmegaConf inference config
    width, height             output size in pixels, rounded down to multiples of 64 (512)
    sampler_name              'ddim', 'plms', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler',
                              'k_heun' or 'k_lms' ('klms' is accepted as an alias of 'k_lms');
                              any other name falls back to plms
    latent_channels           channels of the latent tensor (4)
    downsampling_factor       ratio between image size and latent size (8)
    ddim_eta                  sampler eta; 0.0 makes a given seed reproducible
    precision                 'autocast' runs sampling under torch.autocast; anything else does not
    full_precision            when False the model is converted to half precision on load
    strength                  img2img noising strength in [0.0, 1.0] (0.75)
    embedding_path            optional file passed to model.embedding_manager.load()
    latent_diffusion_weights  only recorded on the instance; not used by this class
    device                    requested torch device; replaced by 'cpu' when CUDA is unavailable
    gfpgan                    optional GFPGAN restorer, used when gfpgan_strength > 0

    The vast majority of these arguments default to reasonable values.
    """

    # Public sampler names that are implemented by KSampler -> KSampler algorithm name.
    _K_SAMPLERS = {
        'k_dpm_2_a': 'dpm_2_ancestral',
        'k_dpm_2': 'dpm_2',
        'k_euler_a': 'euler_ancestral',
        'k_euler': 'euler',
        'k_heun': 'heun',
        'k_lms': 'lms',
    }
    # Legacy spellings. 'klms' is the constructor default; without this alias it was
    # reported as unsupported and silently replaced by PLMS.
    _SAMPLER_ALIASES = {'klms': 'k_lms'}

    def __init__(self,
                 batch_size=1,
                 iterations=1,
                 steps=50,
                 seed=None,
                 cfg_scale=7.5,
                 weights="models/ldm/stable-diffusion-v1/model.ckpt",
                 config="configs/stable-diffusion/v1-inference.yaml",
                 width=512,
                 height=512,
                 sampler_name="klms",
                 latent_channels=4,
                 downsampling_factor=8,
                 ddim_eta=0.0,  # deterministic
                 precision='autocast',
                 full_precision=False,
                 strength=0.75,  # default in scripts/img2img.py
                 embedding_path=None,
                 latent_diffusion_weights=False,  # just to keep track of this parameter when regenerating prompt
                 device='cuda',
                 gfpgan=None,
                 ):
        """Store default generation settings; the model itself is loaded lazily by load_model()."""
        self.batch_size = batch_size
        self.iterations = iterations
        self.width = width
        self.height = height
        self.steps = steps
        self.cfg_scale = cfg_scale
        self.weights = weights
        self.config = config
        self.sampler_name = sampler_name
        self.latent_channels = latent_channels
        self.downsampling_factor = downsampling_factor
        self.ddim_eta = ddim_eta
        self.precision = precision
        self.full_precision = full_precision
        self.strength = strength
        self.embedding_path = embedding_path
        self.model = None  # empty for now
        self.sampler = None
        self.latent_diffusion_weights = latent_diffusion_weights
        self.device = device  # converted to a torch.device in load_model()
        self.gfpgan = gfpgan
        if seed is None:
            self.seed = self._new_seed()
        else:
            self.seed = seed
        # only let the transformers library report errors
        transformers.logging.set_verbosity_error()

    def prompt2png(self, prompt, outdir, **kwargs):
        '''
        Takes a prompt and an output directory, writes out the requested number
        of PNG files, and returns PngWriter.files_written (documented as an array of
        [[filename, seed], [filename, seed] ...]).
        Optional named arguments are the same as those passed to T2I and prompt2image().
        '''
        results = self.prompt2image(prompt, **kwargs)
        # Resolve batch_size the same way prompt2image() does, so an explicit None
        # does not reach PngWriter.
        batch_size = kwargs.get('batch_size') or self.batch_size
        pngwriter = PngWriter(outdir, prompt, batch_size)
        for image, seed in results:
            pngwriter.write_image(image, seed)
        return pngwriter.files_written

    def txt2img(self, prompt, **kwargs):
        '''Deprecated: prompt2png() with outdir defaulting to "outputs/img-samples".'''
        # pop (not get) so outdir is not passed twice to prompt2png()
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        return self.prompt2png(prompt, outdir, **kwargs)

    def img2img(self, prompt, **kwargs):
        '''Deprecated: prompt2png() that requires init_img; outdir defaults to "outputs/img-samples".'''
        # pop (not get) so outdir is not passed twice to prompt2png()
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        assert 'init_img' in kwargs, 'call to img2img() must include the init_img argument'
        return self.prompt2png(prompt, outdir, **kwargs)

    def prompt2image(self,
                     # these are common
                     prompt,
                     batch_size=None,
                     iterations=None,
                     steps=None,
                     seed=None,
                     cfg_scale=None,
                     ddim_eta=None,
                     skip_normalize=False,
                     image_callback=None,
                     # these are specific to txt2img
                     width=None,
                     height=None,
                     # these are specific to img2img
                     init_img=None,
                     strength=None,
                     gfpgan_strength=None,
                     variants=None,
                     **args):  # eat up additional cruft
        '''
        ldm.prompt2image() is the common entry point for txt2img() and img2img().
        It takes the following arguments:
           prompt           // prompt string (no default)
           iterations       // iterations (1); image count = iterations x batch_size
           batch_size       // images per iteration (1)
           steps            // refinement steps per iteration
           seed             // seed for random number generator (0 is a valid seed)
           width            // width of image, in multiples of 64 (512)
           height           // height of image, in multiples of 64 (512)
           cfg_scale        // how strongly the prompt influences the image (7.5) (must be >1)
           init_img         // path to an initial image - its dimensions override width and height
           strength         // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
           gfpgan_strength  // strength for GFPGAN. 0.0 (the default) skips it, 1.0 uses the restored image as-is
           ddim_eta         // image randomness (eta=0.0 means the same seed always produces the same image)
           skip_normalize   // do not divide ':'-weighted sub-prompt weights by their total
           variants         // accepted for img2img, but not used yet
           image_callback   // a function or method that will be called each time an image is generated

        To use the callback, define a function or method that receives two arguments, an Image object
        and the seed. You can then do whatever you like with the image, including converting it to
        different formats and manipulating it. For example:

            def process_image(image, seed):
                image.save(f'images/{seed}.png')

        Returns a list of [Image, seed] rows; every image of one batch carries the
        seed used for that iteration. prompt2png() writes these rows to disk with
        PngWriter from ldm/dream_util.py.

        Raises AssertionError for cfg_scale <= 1.0, strength outside [0.0, 1.0] or a
        missing init_img file.
        '''
        # Where 0 is a legitimate value, test for None explicitly: `x or default`
        # would silently turn seed=0 or strength=0.0 into the instance default.
        seed = self.seed if seed is None else seed
        ddim_eta = self.ddim_eta if ddim_eta is None else ddim_eta
        strength = self.strength if strength is None else strength
        gfpgan_strength = 0.0 if gfpgan_strength is None else gfpgan_strength
        # For these, 0 is not a usable value, so any falsy input means "use the default".
        steps = steps or self.steps
        width = width or self.width
        height = height or self.height
        cfg_scale = cfg_scale or self.cfg_scale
        batch_size = batch_size or self.batch_size
        iterations = iterations or self.iterations

        # Validate before the (slow) model load so bad input fails fast.
        assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0"
        assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
        if init_img:
            assert os.path.exists(init_img), f'{init_img}: File not found'

        model = self.load_model()  # will instantiate the model or return it from cache

        w = int(width / 64) * 64
        h = int(height / 64) * 64
        if h != height or w != width:
            print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
            height = h
            width = w

        data = [batch_size * [prompt]]
        scope = autocast if self.precision == "autocast" else nullcontext

        tic = time.time()
        if init_img:
            results = self._img2img(prompt,
                                    data=data, precision_scope=scope,
                                    batch_size=batch_size, iterations=iterations,
                                    steps=steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta,
                                    skip_normalize=skip_normalize,
                                    init_img=init_img, strength=strength,
                                    gfpgan_strength=gfpgan_strength, variants=variants,
                                    callback=image_callback)
        else:
            results = self._txt2img(prompt,
                                    data=data, precision_scope=scope,
                                    batch_size=batch_size, iterations=iterations,
                                    steps=steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta,
                                    skip_normalize=skip_normalize,
                                    gfpgan_strength=gfpgan_strength,
                                    width=width, height=height,
                                    callback=image_callback)
        toc = time.time()
        print(f'{len(results)} images generated in', "%4.2fs" % (toc - tic))
        return results

    @torch.no_grad()
    def _txt2img(self, prompt,
                 data, precision_scope,
                 batch_size, iterations,
                 steps, seed, cfg_scale, ddim_eta,
                 skip_normalize,
                 gfpgan_strength,
                 width, height,
                 callback):  # the callback is called each time a new Image is generated
        """
        Generate images from the prompt with self.sampler.
        The output is a list of lists in the format: [[image1, seed1], [image2, seed2], ...]
        On KeyboardInterrupt or RuntimeError the images produced so far are returned.
        """
        sampler = self.sampler
        images = list()
        shape = [self.latent_channels,
                 height // self.downsampling_factor,
                 width // self.downsampling_factor]
        try:
            with precision_scope(self.device.type), self.model.ema_scope():
                for _ in trange(iterations, desc="Sampling"):
                    seed_everything(seed)
                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                        uc, c = self._get_uc_and_c(prompts, batch_size, cfg_scale, skip_normalize)
                        samples_ddim, _ = sampler.sample(S=steps,
                                                         conditioning=c,
                                                         batch_size=batch_size,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=cfg_scale,
                                                         unconditional_conditioning=uc,
                                                         eta=ddim_eta)
                        self._samples_to_images(samples_ddim, seed, gfpgan_strength, callback, images)
                    seed = self._new_seed()
        except KeyboardInterrupt:
            print('*interrupted*')
            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
        except RuntimeError:
            self._report_runtime_error()
        return images

    @torch.no_grad()
    def _img2img(self, prompt,
                 data, precision_scope,
                 batch_size, iterations,
                 steps, seed, cfg_scale, ddim_eta,
                 skip_normalize,
                 gfpgan_strength,
                 init_img, strength, variants,
                 callback):
        """
        Generate images from the prompt, starting from the image file at init_img.
        The output is a list of lists in the format: [[image, seed1], [image, seed2], ...]
        `variants` is accepted but not used yet. On KeyboardInterrupt or RuntimeError
        the images produced so far are returned.
        """
        # This path relies on sampler.make_schedule()/stochastic_encode()/decode(), and
        # only the DDIM sampler is treated as supported here, so fall back to DDIM.
        if self.sampler_name != 'ddim':
            print(f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler")
            sampler = DDIMSampler(self.model, device=self.device)
        else:
            sampler = self.sampler

        init_image = self._load_img(init_img).to(self.device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        with precision_scope(self.device.type):
            # move to latent space
            init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))

        sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
        # number of steps of noise added to the initial latent
        t_enc = int(strength * steps)

        images = list()
        try:
            with precision_scope(self.device.type), self.model.ema_scope():
                for _ in trange(iterations, desc="Sampling"):
                    seed_everything(seed)
                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                        uc, c = self._get_uc_and_c(prompts, batch_size, cfg_scale, skip_normalize)
                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(init_latent,
                                                          torch.tensor([t_enc] * batch_size).to(self.device))
                        # decode it
                        samples = sampler.decode(z_enc, c, t_enc,
                                                 unconditional_guidance_scale=cfg_scale,
                                                 unconditional_conditioning=uc)
                        self._samples_to_images(samples, seed, gfpgan_strength, callback, images)
                    seed = self._new_seed()
        except KeyboardInterrupt:
            print('*interrupted*')
            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
        except RuntimeError:
            self._report_runtime_error()
        return images

    def _get_uc_and_c(self, prompts, batch_size, cfg_scale, skip_normalize):
        """Return (uc, c): the unconditional and conditional embeddings for one batch.

        uc is None when cfg_scale == 1.0. When the first prompt contains ':'-weighted
        sub-prompts, c is the weighted sum of their embeddings; the weights are divided
        by their total unless skip_normalize is set or the total is zero.
        """
        uc = None
        if cfg_scale != 1.0:
            uc = self.model.get_learned_conditioning(batch_size * [""])
        if isinstance(prompts, tuple):
            prompts = list(prompts)
        subprompts, weights = T2I._split_weighted_subprompts(prompts[0])
        if len(subprompts) > 1:
            # Accumulate from zeros shaped like uc. This needs uc to be set, which
            # prompt2image() guarantees by asserting cfg_scale > 1.0.
            c = torch.zeros_like(uc)
            total_weight = sum(weights)
            # a zero total would raise ZeroDivisionError; use the raw weights instead
            normalize = not skip_normalize and total_weight != 0
            for subprompt, weight in zip(subprompts, weights):
                if normalize:
                    weight = weight / total_weight
                c = torch.add(c, self.model.get_learned_conditioning(subprompt), alpha=weight)
        else:  # just standard 1 prompt
            c = self.model.get_learned_conditioning(prompts)
        return uc, c

    def _samples_to_images(self, samples, seed, gfpgan_strength, callback, images):
        """Decode latent samples into PIL images.

        Each image is optionally passed through GFPGAN, appended to `images` as
        [image, seed] (immediately, so an interrupt keeps partial results) and handed
        to `callback` when one is given.
        """
        x_samples = self.model.decode_first_stage(samples)
        # map decoder output from [-1, 1] to [0, 1]
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
        for x_sample in x_samples:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            image = Image.fromarray(x_sample.astype(np.uint8))
            if gfpgan_strength and gfpgan_strength > 0:
                image = self._run_gfpgan(image, gfpgan_strength)
            images.append([image, seed])
            if callback is not None:
                callback(image, seed)

    @staticmethod
    def _report_runtime_error():
        """Print the standard bug-report hint and the active exception's traceback.

        Must be called from inside an `except` clause.
        """
        print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
        traceback.print_exc()

    def _new_seed(self):
        """Draw a random seed in [0, 2**32 - 1), store it in self.seed and return it."""
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def load_model(self):
        """Load and initialize the model from configuration variables passed at object creation time.

        The model is loaded only once; the sampler named by self.sampler_name is
        (re)created on every call. Returns the model. Exits the process with status 1
        if loading fails with an AttributeError.
        """
        if self.model is None:
            seed_everything(self.seed)
            try:
                config = OmegaConf.load(self.config)
                self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
                model = self._load_model_from_config(config, self.weights)
                if self.embedding_path is not None:
                    model.embedding_manager.load(self.embedding_path)
                self.model = model.to(self.device)
                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                self.model.cond_stage_model.device = self.device
            except AttributeError as err:
                # a bare `raise SystemExit` would exit with status 0 and hide the cause
                traceback.print_exc()
                raise SystemExit(1) from err

        name = self._SAMPLER_ALIASES.get(self.sampler_name, self.sampler_name)
        msg = f'setting sampler to {self.sampler_name}'
        if name == 'plms':
            self.sampler = PLMSSampler(self.model, device=self.device)
        elif name == 'ddim':
            self.sampler = DDIMSampler(self.model, device=self.device)
        elif name in self._K_SAMPLERS:
            self.sampler = KSampler(self.model, self._K_SAMPLERS[name], device=self.device)
        else:
            msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
            self.sampler = PLMSSampler(self.model, device=self.device)

        print(msg)
        return self.model

    def _load_model_from_config(self, config, ckpt):
        """Instantiate config.model and load the checkpoint at `ckpt` into it.

        The model is moved to self.device and put in eval mode. Unless
        self.full_precision is set, it is also converted to half precision.
        """
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the checkpoint; only load files from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        # strict=False: missing/unexpected keys are tolerated and not reported
        model.load_state_dict(sd, strict=False)
        model.to(self.device)
        model.eval()
        if self.full_precision:
            print('Using slower but more accurate full-precision math (--full_precision)')
        else:
            print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.')
            model.half()
        return model

    def _load_img(self, path):
        """Load an image file as a (1, 3, H, W) float tensor scaled to [-1, 1].

        H and W are the file's dimensions rounded down to multiples of 32 (LANCZOS resize).
        """
        with Image.open(path) as img:
            image = img.convert("RGB")
        w, h = image.size
        print(f"loaded input image of size ({w}, {h}) from {path}")
        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return 2. * image - 1.

    @staticmethod
    def _split_weighted_subprompts(text):
        """
        Split "a:w1 b:w2 c" into (["a", "b", "c"], [w1, w2, 1.0]).

        Text up to each ':' is a sub-prompt. The token after ':' (up to the next
        space) is its weight; a missing weight defaults to 1.0, and a non-numeric
        weight prints a warning and also defaults to 1.0. Trailing text without a
        ':' becomes a final sub-prompt of weight 1.0.
        """
        prompts = []
        weights = []
        while ":" in text:
            # grab up to the first ':' as sub-prompt and drop the ':'
            prompt, _, text = text.partition(":")
            # the weight runs up to the first space (or to the end)
            idx = text.index(" ") if " " in text else len(text)
            if idx != 0:
                try:
                    weight = float(text[:idx])
                except ValueError:  # couldn't treat as float
                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
                    weight = 1.0
            else:  # no value found
                weight = 1.0
            # drop the weight and the space after it
            text = text[idx + 1:]
            prompts.append(prompt)
            weights.append(weight)
        if len(text) > 0:  # no ':' left, but there is still text: take it with weight 1
            prompts.append(text)
            weights.append(1.0)
        return prompts, weights

    def _run_gfpgan(self, image, strength):
        """Restore faces in `image` with self.gfpgan and blend the result back in.

        strength >= 1.0 returns the restored image as-is; smaller values mix it with
        the (RGB-converted) input via Image.blend(). Without a GFPGAN object the
        input image is returned unchanged.
        """
        if self.gfpgan is None:
            print("GFPGAN not initialized, it must be loaded via the --gfpgan argument")
            return image
        image = image.convert("RGB")
        _cropped_faces, _restored_faces, restored_img = self.gfpgan.enhance(
            np.array(image, dtype=np.uint8),
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )
        res = Image.fromarray(restored_img)
        if strength < 1.0:
            # NOTE(review): Image.blend requires equal size and mode; this assumes the
            # restorer does not upscale - confirm against how self.gfpgan is configured.
            res = Image.blend(image, res, strength)
        return res