mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for runwayML custom inpainting model
This is still a work in progress but seems functional. It supports inpainting, txt2img and img2img on the ddim and k* samplers (plms still needs work, but I know what to do). To test this, get the file `sd-v1-5-inpainting.ckpt' from https://huggingface.co/runwayml/stable-diffusion-inpainting and place it at `models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt` Launch invoke.py with --model inpainting-1.5 and proceed as usual. Caveats: 1. The inpainting model takes about 800 Mb more memory than the standard 1.5 model. This model will not work on 4 GB cards. 2. The inpainting model is temperamental. It wants you to describe the entire scene and not just the masked area to replace. So if you want to replace the parrot on a man's shoulder with a crow, the prompt "crow" may fail. Try "man with a crow on shoulder" instead. The symptom of a failed inpainting is that the area will be erased and replaced with background. 3. This has not been tested well. Please report bugs.
This commit is contained in:
@ -12,6 +12,22 @@ from ldm.modules.diffusionmodules.util import (
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
def make_cond_in(uncond, cond):
|
||||
if isinstance(cond, dict):
|
||||
assert isinstance(uncond, dict)
|
||||
cond_in = dict()
|
||||
for k in cond:
|
||||
if isinstance(cond[k], list):
|
||||
cond_in[k] = [
|
||||
torch.cat([uncond[k][i], cond[k][i]])
|
||||
for i in range(len(cond[k]))
|
||||
]
|
||||
else:
|
||||
cond_in[k] = torch.cat([uncond[k], cond[k]])
|
||||
else:
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
return cond_in
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
return result
|
||||
@ -37,7 +53,7 @@ class CFGDenoiser(nn.Module):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
cond_in = make_cond_in(uncond,cond)
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
@ -64,13 +80,12 @@ class KSampler(Sampler):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
cond_in = make_cond_in(uncond, cond)
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
@ -283,3 +298,4 @@ class KSampler(Sampler):
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.inner_model.model.conditioning_key
|
||||
|
||||
|
Reference in New Issue
Block a user