Merge branch 'main' of https://github.com/BaristaLabs/stable-diffusion-dream into add-gfpgan-option

Sean McLellan 2022-08-25 23:19:17 -04:00
commit dbb9132f4d
4 changed files with 136 additions and 163 deletions


@@ -105,11 +105,18 @@ So for instance, to apply the maximum strength:
 dream> a man wearing a pineapple hat -G 1
 ~~~~
+This also works with img2img:
+~~~
+dream> a man wearing a pineapple hat -I path/to/your/file.png -G 1
+~~~
 That's it!
 There are also several options for controlling GFPGAN settings when starting the script, which you can
 read about in the help text (an example command is shown below). These let you control where GFPGAN is installed, whether upsampling is enabled, which upsampler to use, and the model path.
+Note that loading GFPGAN consumes additional GPU memory, and it adds a couple of seconds to each image generation.
 ## Barebones Web Server
 As of version 1.10, this distribution comes with a bare-bones web server (see screenshot). To use it,
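To see the GFPGAN-related startup options described above, print the script's help text. A minimal sketch, assuming the launcher lives at scripts/dream.py as in this repository (check your checkout for the exact path and flag names):
~~~
# Lists all startup flags, including those for the GFPGAN install location, upsampler, and model path
python3 scripts/dream.py --help
~~~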


@@ -220,7 +220,7 @@ The vast majority of these arguments default to reasonable values.
 gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
 ddim_eta        // image randomness (eta=0.0 means the same seed always produces the same image)
 variants        // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
-callback        // a function or method that will be called each time an image is generated
+image_callback  // a function or method that will be called each time an image is generated
 To use the callback, define a function or method that receives two arguments, an Image object
 and the seed; a sketch follows below. You can then do whatever you like with the image, including converting it to
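For illustration, here is a minimal sketch of such a callback. The module path, the T2I constructor arguments, and the prompt2image method name are assumptions (adjust them to your copy of the code); what is taken from the text above is that the callback receives a PIL Image and the seed, and is passed as image_callback:
~~~
from ldm.simplet2i import T2I  # import path assumed; adjust to your checkout

def save_with_seed(image, seed):
    # image is a PIL Image, seed is the integer seed that produced it;
    # write the image to disk, named after the seed so it can be regenerated later.
    image.save(f'pineapple.{seed}.png')

t2i = T2I()                    # default model/config assumed
t2i.prompt2image(              # method name is an assumption
    prompt='a man wearing a pineapple hat',
    gfpgan_strength=0.8,       # 0.0 preserves the image exactly, 1.0 replaces it completely
    image_callback=save_with_seed,
)
~~~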
@@ -253,82 +253,71 @@ The vast majority of these arguments default to reasonable values.
         height = h
         width  = w
-        data   = [batch_size * [prompt]]
         scope  = autocast if self.precision=="autocast" else nullcontext
         tic    = time.time()
+        results = list()

-        if init_img:
-            assert os.path.exists(init_img),f'{init_img}: File not found'
-            results = self._img2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    init_img=init_img,strength=strength,
-                                    gfpgan_strength=gfpgan_strength,variants=variants,
-                                    callback=image_callback)
-        else:
-            results = self._txt2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    gfpgan_strength=gfpgan_strength,
-                                    width=width,height=height,
-                                    callback=image_callback)
+        try:
+            if init_img:
+                assert os.path.exists(init_img),f'{init_img}: File not found'
+                images_iterator = self._img2img(prompt,
+                                                precision_scope=scope,
+                                                batch_size=batch_size,
+                                                steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                                                skip_normalize=skip_normalize,
+                                                gfpgan_strength=gfpgan_strength,
+                                                init_img=init_img,strength=strength)
+            else:
+                images_iterator = self._txt2img(prompt,
+                                                precision_scope=scope,
+                                                batch_size=batch_size,
+                                                steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                                                skip_normalize=skip_normalize,
+                                                gfpgan_strength=gfpgan_strength,
+                                                width=width,height=height)
+
+            with scope(self.device.type), self.model.ema_scope():
+                for n in trange(iterations, desc="Sampling"):
+                    seed_everything(seed)
+                    iter_images = next(images_iterator)
+                    for image in iter_images:
+                        results.append([image, seed])
+                        if image_callback is not None:
+                            image_callback(image,seed)
+                    seed = self._new_seed()
+        except KeyboardInterrupt:
+            print('*interrupted*')
+            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
+        except RuntimeError as e:
+            print(str(e))
+            print('Are you sure your system has an adequate NVIDIA GPU?')

         toc = time.time()
         print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
         return results

     @torch.no_grad()
-    def _txt2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _txt2img(self,
+                 prompt,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  gfpgan_strength,
-                 width,height,
-                 callback):    # the callback is called each time a new Image is generated
+                 width,height):
         """
-        Generate an image from the prompt, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
+        An infinite iterator of images from the prompt.
         """
         sampler = self.sampler
-        images = list()
-        image_count = 0

-        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
-        try:
-            with precision_scope(self.device.type), self.model.ema_scope():
-                all_samples = list()
-                for n in trange(iterations, desc="Sampling"):
-                    seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        uc = None
-                        if cfg_scale != 1.0:
-                            uc = self.model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
-
-                        # weighted sub-prompts
-                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                        if len(subprompts) > 1:
-                            # i dont know if this is correct.. but it works
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
-
-                        shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-                        samples_ddim, _ = sampler.sample(S=steps,
-                                                         conditioning=c,
-                                                         batch_size=batch_size,
-                                                         shape=shape,
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
+            shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+            samples, _ = sampler.sample(S=steps,
+                                        conditioning=c,
+                                        batch_size=batch_size,
+                                        shape=shape,
@@ -336,39 +325,19 @@ The vast majority of these arguments default to reasonable values.
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta)
+            yield self._samples_to_images(samples, gfpgan_strength=gfpgan_strength)

-                        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
-                        for x_sample in x_samples_ddim:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            if gfpgan_strength > 0:
-                                image = self._run_gfpgan(image, gfpgan_strength)
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-                    seed = self._new_seed()
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print(str(e))
-        return images

     @torch.no_grad()
-    def _img2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _img2img(self,
+                 prompt,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  gfpgan_strength,
-                 init_img,strength,variants,
-                 callback):
+                 init_img,strength):
         """
-        Generate an image from the prompt and the initial image, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
+        An infinite iterator of images from the prompt and the initial image
         """

         # PLMS sampler not supported yet, so ignore previous sampler
@ -387,22 +356,24 @@ The vast majority of these arguments default to reasonable values.
t_enc = int(strength * steps) t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps") # print(f"target t_enc is {t_enc} steps")
images = list()
try: while True:
with precision_scope(self.device.type), self.model.ema_scope(): uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
all_samples = list()
for n in trange(iterations, desc="Sampling"): # encode (scaled latent)
seed_everything(seed) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
for prompts in tqdm(data, desc="data", dynamic_ncols=True): # decode it
uc = None samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
if cfg_scale != 1.0: unconditional_conditioning=uc,)
yield self._samples_to_images(samples, gfpgan_strength)
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
uc = self.model.get_learned_conditioning(batch_size * [""]) uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
# weighted sub-prompts # weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) subprompts,weights = T2I._split_weighted_subprompts(prompt)
if len(subprompts) > 1: if len(subprompts) > 1:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
@ -413,35 +384,24 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i] weight = weights[i]
if not skip_normalize: if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
else: # just standard 1 prompt else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts) c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
def _samples_to_images(self, samples, gfpgan_strength=0):
x_samples = self.model.decode_first_stage(samples) x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list()
for x_sample in x_samples: for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8)) image = Image.fromarray(x_sample.astype(np.uint8))
try:
if gfpgan_strength > 0: if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength) image = self._run_gfpgan(image, gfpgan_strength)
images.append([image,seed]) except Exception:
if callback is not None: print(f"Error running GFPGAN - Your image was not enhanced.")
callback(image,seed) images.append(image)
seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
traceback.print_exc()
return images return images
def _new_seed(self): def _new_seed(self):


@@ -301,7 +301,7 @@ def create_argv_parser():
         '-o',
         type=str,
         default="outputs/img-samples",
-        help="directory in which to place generated images and a log of prompts and seeds")
+        help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)")
     parser.add_argument('--embedding_path',
                         type=str,
                         help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")


@@ -85,6 +85,12 @@ class DreamServer(BaseHTTPRequestHandler):
         print(f"Prompt generated with output: {outputs}")
         post_data['initimg'] = '' # Don't send init image back

+        # Append post_data to log
+        with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
+            for output in outputs:
+                log.write(f"{output[0]}: {json.dumps(post_data)}\n")
+
         outputs = [x + [post_data] for x in outputs] # Append config to each output
         result = {'outputs': outputs}
         self.wfile.write(bytes(json.dumps(result), "utf-8"))