Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
add support for runwayML custom inpainting model
This is still a work in progress but appears functional. It supports inpainting, txt2img and img2img on the ddim and k* samplers (plms still needs work, but I know what to do).

To test this, get the file `sd-v1-5-inpainting.ckpt` from https://huggingface.co/runwayml/stable-diffusion-inpainting and place it at `models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt`. Launch invoke.py with `--model inpainting-1.5` and proceed as usual.

Caveats:

1. The inpainting model takes about 800 MB more memory than the standard 1.5 model. It will not work on 4 GB cards.

2. The inpainting model is temperamental. It wants you to describe the entire scene, not just the masked area to replace. So if you want to replace the parrot on a man's shoulder with a crow, the prompt "crow" may fail; try "man with a crow on his shoulder" instead. The symptom of a failed inpainting is that the masked area is simply erased and replaced with background.

3. This has not been tested well. Please report bugs.
Parent: aaf7a4f1d3
Commit: b101be041b
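For context on caveats 1 and 2: the runwayML inpainting checkpoint conditions its UNet on the noisy latent, a downsampled mask, and the VAE encoding of the masked image, concatenated along the channel dimension. The following is a minimal shape-only sketch of that layout, with placeholder tensors (this is an illustration of the model's input convention, not InvokeAI code):

```python
import torch

# Illustrative shapes for a 512x512 image with an 8x-downsampling VAE.
# In the real pipeline these come from the VAE encoder and the user mask.
latent       = torch.randn(1, 4, 64, 64)   # noisy latent being denoised
mask         = torch.ones(1, 1, 64, 64)    # single-channel mask, resized to latent resolution
masked_image = torch.randn(1, 4, 64, 64)   # VAE encoding of the image with the masked region blanked

# The inpainting UNet expects 4 + 1 + 4 = 9 input channels instead of the
# usual 4, which is why it needs its own checkpoint and inference config.
unet_input = torch.cat([latent, mask, masked_image], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
```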
@@ -13,6 +13,13 @@ stable-diffusion-1.4:
   default: true
   width: 512
   height: 512
+inpainting-1.5:
+  description: runwayML tuned inpainting model v1.5
+  weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+  config: configs/stable-diffusion/v1-inpainting-inference.yaml
+  # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+  width: 512
+  height: 512
 stable-diffusion-1.5:
   config: configs/stable-diffusion/v1-inference.yaml
   weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
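The new stanza registers the checkpoint under the name `inpainting-1.5`, which is what the `--model inpainting-1.5` switch selects. A rough sketch of how such a YAML model registry might be read (the loader and the registry path here are hypothetical; only the key names follow the hunk above):

```python
import yaml  # PyYAML

def load_model_entry(registry_path: str, model_name: str) -> dict:
    """Return the stanza for one named model from a models.yaml-style registry."""
    with open(registry_path) as f:
        registry = yaml.safe_load(f)
    if model_name not in registry:
        raise KeyError(f"unknown model '{model_name}'")
    return registry[model_name]

# Hypothetical usage; the actual registry path is not shown in this diff.
entry = load_model_entry("configs/models.yaml", "inpainting-1.5")
print(entry["weights"], entry["config"])
```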
@@ -4,6 +4,7 @@ import torch
 import numpy as np
 from einops import repeat
 from PIL import Image, ImageOps
+from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -60,7 +61,7 @@ class Omnibus(Img2Img,Txt2Img):
                 ) # move to latent space

             # create a completely black mask (1s)
-            mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
+            mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
             # and the masked image is just a copy of the original
             masked_image = init_image

         t_enc = int(strength * steps)
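The corrected call builds the mask from the tensor's own dimensions rather than PIL-style `.width`/`.height` attributes, so the single-channel mask always matches the spatial size of `init_image`. A small self-contained shape check, assuming `init_image` is an NCHW tensor as in the hunk above:

```python
import torch

init_image = torch.randn(1, 3, 512, 512)  # stand-in NCHW image tensor

# One channel, same spatial dims as the source tensor (H = shape[2], W = shape[3]).
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3])

assert mask_image.shape[-2:] == init_image.shape[-2:]
print(mask_image.shape)  # torch.Size([1, 1, 512, 512])
```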
@@ -74,7 +75,8 @@ class Omnibus(Img2Img,Txt2Img):

         def make_image(x_T):
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                scope = choose_autocast(self.precision)
+                with scope(self.model.device.type):

                     batch = self.make_batch_sd(
                         init_image,
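Replacing the hard-coded `torch.autocast("cuda")` with `choose_autocast(self.precision)` lets the same code path respect the user's precision setting and run on devices where CUDA autocast is unavailable. A minimal sketch of what such a selector could look like (an illustration only, not the actual `ldm.invoke.devices` implementation):

```python
from contextlib import nullcontext
import torch

def choose_autocast_sketch(precision: str):
    """Return a context-manager factory: autocast when requested, a no-op otherwise."""
    if precision == 'autocast':
        return torch.autocast  # called later as torch.autocast(device_type)
    return lambda device_type: nullcontext()

scope = choose_autocast_sketch('autocast')
with scope('cuda' if torch.cuda.is_available() else 'cpu'):
    pass  # model calls would go here
```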
@@ -139,15 +141,10 @@ class Omnibus(Img2Img,Txt2Img):
                 "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
                 "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
                 }
-        print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
-        print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
-        print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
         return batch

     def get_noise(self, width:int, height:int):
         if self.init_latent:
-            print('DEBUG: returning Img2Img.getnoise()')
             return super(Img2Img,self).get_noise(width,height)
         else:
-            print('DEBUG: returning Txt2Img.getnoise()')
             return super(Txt2Img,self).get_noise(width,height)
@@ -12,6 +12,22 @@ from ldm.modules.diffusionmodules.util import (
     extract_into_tensor,
 )

+def make_cond_in(uncond, cond):
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+    else:
+        cond_in = torch.cat([uncond, cond])
+    return cond_in
+
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
         return result
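`make_cond_in` generalizes the old `torch.cat([uncond, cond])` so that classifier-free guidance also works when the conditioning is a dict, as with the inpainting model's hybrid conditioning that carries both cross-attention embeddings and concatenated image/mask channels. A toy example of the two cases it handles, using dummy tensors and illustrative key names, assuming `make_cond_in` from the hunk above is in scope:

```python
import torch

# Plain tensor conditioning (txt2img / img2img): the batch simply doubles.
uncond = torch.randn(1, 77, 768)
cond   = torch.randn(1, 77, 768)
print(make_cond_in(uncond, cond).shape)           # torch.Size([2, 77, 768])

# Dict conditioning (key names are illustrative): every entry doubles,
# including entries that are lists of tensors.
uncond_d = {'c_crossattn': [torch.randn(1, 77, 768)], 'c_concat': [torch.randn(1, 5, 64, 64)]}
cond_d   = {'c_crossattn': [torch.randn(1, 77, 768)], 'c_concat': [torch.randn(1, 5, 64, 64)]}
merged = make_cond_in(uncond_d, cond_d)
print(merged['c_crossattn'][0].shape)             # torch.Size([2, 77, 768])
print(merged['c_concat'][0].shape)                # torch.Size([2, 5, 64, 64])
```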
@@ -37,7 +53,7 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
+        cond_in = make_cond_in(uncond,cond)
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -64,13 +80,12 @@ class KSampler(Sampler):
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
+        cond_in = make_cond_in(uncond, cond)
         uncond, cond = self.inner_model(
             x_in, sigma_in, cond=cond_in
         ).chunk(2)
         return uncond + (cond - uncond) * cond_scale

-
     def make_schedule(
         self,
         ddim_num_steps,
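Both `forward` methods end with the standard classifier-free guidance blend: the model runs once on a doubled batch, the result is split into unconditioned and conditioned predictions, and the conditioned one is pushed away from the unconditioned one by `cond_scale`. A tiny numeric illustration of just that blend:

```python
import torch

uncond = torch.tensor([0.2, 0.4])   # model output for the empty/negative prompt
cond   = torch.tensor([0.5, 0.1])   # model output for the actual prompt
cond_scale = 7.5

guided = uncond + (cond - uncond) * cond_scale
print(guided)  # tensor([ 2.4500, -1.8500])
```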
@@ -283,3 +298,4 @@ class KSampler(Sampler):

     def conditioning_key(self)->str:
         return self.model.inner_model.model.conditioning_key
+