Damian at mba 2022-10-21 15:07:11 +02:00
parent e574a1574f
commit 64051d081c
6 changed files with 22 additions and 201 deletions

View File

@@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
     rfc_dict["sampler"] = parameters["sampler_name"]
     # display weighted subprompts (liable to change)
-    subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
+    subprompts = split_weighted_subprompts(parameters["prompt"])
     subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
     rfc_dict["prompt"] = subprompts

View File

@@ -1,186 +0,0 @@
import random
import traceback

import numpy as np
import torch
from diffusers import LMSDiscreteScheduler
from PIL import Image
from torch import autocast
from tqdm.auto import tqdm

import ldm.models.diffusion.cross_attention as c_a_c  # cross-attention control helpers, used below as c_a_c
@torch.no_grad()
def stablediffusion(
    clip,
    clip_tokenizer,
    device,
    vae,
    unet,
    prompt='',
    prompt_edit=None,
    prompt_edit_token_weights=None,
    prompt_edit_tokens_start=0.0,
    prompt_edit_tokens_end=1.0,
    prompt_edit_spatial_start=0.0,
    prompt_edit_spatial_end=1.0,
    guidance_scale=7.5,
    steps=50,
    seed=None,
    width=512,
    height=512,
    init_image=None,
    init_image_strength=0.5,
):
    if prompt_edit_token_weights is None:
        prompt_edit_token_weights = []

    # Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64

    # If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.manual_seed(seed)

    # Set inference timesteps to scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                     beta_end=0.012,
                                     beta_schedule='scaled_linear',
                                     num_train_timesteps=1000,
                                     )
    scheduler.set_timesteps(steps)

    # Preprocess image if it exists (img2img)
    if init_image is not None:
        # Resize and transpose for numpy b h w c -> torch b c h w
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        # If there is alpha channel, composite alpha for white, as the diffusion
        # model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        # Move image to GPU
        init_image = init_image.to(device)

        # Encode image
        with autocast(device):
            init_latent = (vae.encode(init_image)
                           .latent_dist
                           .sample(generator=generator)
                           * 0.18215)
        t_start = steps - int(steps * init_image_strength)
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8),
                                  device=device)
        t_start = 0

    # Generate random normal noise
    noise = torch.randn(init_latent.shape, generator=generator, device=device)
    latent = scheduler.add_noise(init_latent,
                                 noise,
                                 torch.tensor([scheduler.timesteps[t_start]], device=device)
                                 ).to(device)

    # Process clip
    with autocast(device):
        tokens_uncond = clip_tokenizer('', padding='max_length',
                                       max_length=clip_tokenizer.model_max_length,
                                       truncation=True, return_tensors='pt',
                                       return_overflowing_tokens=True
                                       )
        embedding_uncond = clip(tokens_uncond.input_ids.to(device)).last_hidden_state

        tokens_cond = clip_tokenizer(prompt, padding='max_length',
                                     max_length=clip_tokenizer.model_max_length,
                                     truncation=True, return_tensors='pt',
                                     return_overflowing_tokens=True
                                     )
        embedding_cond = clip(tokens_cond.input_ids.to(device)).last_hidden_state

        # Process prompt editing
        if prompt_edit is not None:
            tokens_cond_edit = clip_tokenizer(prompt_edit, padding='max_length',
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors='pt',
                                              return_overflowing_tokens=True
                                              )
            embedding_cond_edit = clip(tokens_cond_edit.input_ids.to(device)).last_hidden_state
            c_a_c.init_attention_edit(tokens_cond, tokens_cond_edit)

        c_a_c.init_attention_func()
        c_a_c.init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]
        for idx, timestep in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + idx
            latent_model_input = latent
            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

            # Predict the unconditional noise residual
            noise_pred_uncond = unet(latent_model_input,
                                     timestep,
                                     encoder_hidden_states=embedding_uncond
                                     ).sample

            # Prepare the Cross-Attention layers
            if prompt_edit is not None:
                c_a_c.save_last_tokens_attention()
                c_a_c.save_last_self_attention()
            else:
                # Use weights on non-edited prompt when edit is None
                c_a_c.use_last_tokens_attention_weights()

            # Predict the conditional noise residual and save the
            # cross-attention layer activations
            noise_pred_cond = unet(latent_model_input,
                                   timestep,
                                   encoder_hidden_states=embedding_cond
                                   ).sample

            # Edit the Cross-Attention layer activations
            if prompt_edit is not None:
                t_scale = timestep / scheduler.num_train_timesteps
                if (t_scale >= prompt_edit_tokens_start
                        and t_scale <= prompt_edit_tokens_end):
                    c_a_c.use_last_tokens_attention()
                if (t_scale >= prompt_edit_spatial_start
                        and t_scale <= prompt_edit_spatial_end):
                    c_a_c.use_last_self_attention()

                # Use weights on edited prompt
                c_a_c.use_last_tokens_attention_weights()

                # Predict the edited conditional noise residual using the
                # cross-attention masks
                noise_pred_cond = unet(latent_model_input,
                                       timestep,
                                       encoder_hidden_states=embedding_cond_edit
                                       ).sample

            # Perform guidance
            noise_pred = (noise_pred_uncond + guidance_scale
                          * (noise_pred_cond - noise_pred_uncond))

            latent = scheduler.step(noise_pred,
                                    t_index,
                                    latent
                                    ).prev_sample

        # scale and decode the image latents with vae
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype('uint8')
    return Image.fromarray(image)
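
For orientation, a minimal sketch of how this now-deleted helper would have been driven. The model-loading calls and checkpoint names below are assumptions for illustration; only the stablediffusion() signature comes from the file above.

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel

# Hypothetical setup; checkpoint names are illustrative assumptions.
device = 'cuda'
clip_tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
clip = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(device)
vae = AutoencoderKL.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='vae').to(device)
unet = UNet2DConditionModel.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='unet').to(device)

# Prompt-to-prompt style edit: same layout, "cat" swapped for "dog".
image = stablediffusion(
    clip, clip_tokenizer, device, vae, unet,
    prompt='a photo of a cat sitting on a chair',
    prompt_edit='a photo of a dog sitting on a chair',
    seed=42,
    steps=50,
)
image.save('edited.png')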

View File

@@ -51,9 +51,8 @@ class Img2Img(Generator):
                 img_callback = step_callback,
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
-                init_latent = self.init_latent,
+                init_latent = self.init_latent,    # changes how noising is performed in ksampler
                 extra_conditioning_info = extra_conditioning_info
-                # changes how noising is performed in ksampler
             )
             return self.sample_to_image(samples)
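
The inline comment refers to how a k-sampler builds its starting latent when an init_latent is supplied. A rough sketch of the idea only, not this repository's KSampler code; the function and parameter names below are illustrative:

import torch

def build_start_latent(shape, sigmas, init_latent=None, device='cuda'):
    # img2img: layer noise on top of the supplied init latent at the starting sigma;
    # txt2img: start from pure noise scaled by the largest sigma.
    noise = torch.randn(shape, device=device)
    if init_latent is not None:
        return init_latent + noise * sigmas[0]
    return noise * sigmas[0]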

View File

@@ -29,9 +29,9 @@ work fine.
 import torch
 import numpy as np
-from models.clipseg import CLIPDensePredT
+from clipseg_models.clipseg import CLIPDensePredT
 from einops import rearrange, repeat
-from PIL import Image
+from PIL import Image, ImageOps
 from torchvision import transforms
 CLIP_VERSION = 'ViT-B/16'
@@ -50,9 +50,14 @@ class SegmentedGrayscale(object):
         discrete_heatmap = self.heatmap.lt(threshold).int()
         return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
-    def to_transparent(self)->Image:
+    def to_transparent(self,invert:bool=False)->Image:
         transparent_image = self.image.copy()
-        transparent_image.putalpha(self.to_grayscale())
+        gs = self.to_grayscale()
+        # The following line looks like a bug, but isn't.
+        # For img2img, we want the selected regions to be transparent,
+        # but to_grayscale() returns the opposite.
+        gs = ImageOps.invert(gs) if not invert else gs
+        transparent_image.putalpha(gs)
         return transparent_image
     # unscales and uncrops the 352x352 heatmap so that it matches the image again
@@ -79,7 +84,7 @@ class Txt2Mask(object):
         self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
     @torch.no_grad()
-    def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
+    def segment(self, image, prompt:str) -> SegmentedGrayscale:
         '''
         Given a prompt string such as "a bagel", tries to identify the object in the
         provided image and returns a SegmentedGrayscale object in which the brighter
@@ -94,6 +99,10 @@ class Txt2Mask(object):
             transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
         ])
+        if type(image) is str:
+            image = Image.open(image).convert('RGB')
+        image = ImageOps.exif_transpose(image)
         img = self._scale_and_crop(image)
         img = transform(img).unsqueeze(0)
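
A short usage sketch of the updated interface; the import path, constructor arguments, and file names are assumptions for illustration:

from ldm.invoke.txt2mask import Txt2Mask   # import path is an assumption

txt2mask = Txt2Mask(device='cpu')
segmented = txt2mask.segment('photo.png', 'a bagel')   # a file path is now accepted directly
heatmap = segmented.to_grayscale()                     # brighter pixels = better match
cutout = segmented.to_transparent()                    # matched region made transparent (img2img use)
keep = segmented.to_transparent(invert=True)           # matched region kept opaque instead
cutout.save('bagel_mask.png')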

View File

@@ -1,5 +1,4 @@
 """SAMPLING ONLY."""
-from typing import Union
 import torch
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -29,7 +28,7 @@ class DDIMSampler(Sampler):
     def p_sample(
         self,
         x,
-        c: Union[torch.Tensor, list],
+        c,
         t,
         index,
         repeat_noise=False,

View File

@@ -8,7 +8,7 @@ import numpy as np
 from einops import rearrange
 from ldm.util import instantiate_from_config
-#from ldm.modules.attention import LinearAttention
+from ldm.modules.attention import LinearAttention
 import psutil
@@ -151,10 +151,10 @@ class ResnetBlock(nn.Module):
         return x + h
-#class LinAttnBlock(LinearAttention):
-#    """to match AttnBlock usage"""
-#    def __init__(self, in_channels):
-#        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+class LinAttnBlock(LinearAttention):
+    """to match AttnBlock usage"""
+    def __init__(self, in_channels):
+        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
 class AttnBlock(nn.Module):
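
For context, LinAttnBlock is normally wired in through the attention-type selector used elsewhere in model.py. A hedged sketch of that selection pattern follows; details may differ from this repository's exact code:

import torch.nn as nn

def make_attn(in_channels, attn_type='vanilla'):
    # 'vanilla' -> standard AttnBlock, 'linear' -> LinAttnBlock, 'none' -> pass-through
    assert attn_type in ('vanilla', 'linear', 'none'), f'unknown attn_type {attn_type}'
    if attn_type == 'vanilla':
        return AttnBlock(in_channels)
    elif attn_type == 'none':
        return nn.Identity()
    return LinAttnBlock(in_channels)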