mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
inpainting for the normal model [WIP]
This seems to be performing well until the LAST STEP, at which point it dissolves to confetti.
This commit is contained in:
parent
b2664e807e
commit
f3570d8344
@ -22,6 +22,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
@ -61,10 +62,32 @@ class AddsMaskLatents:
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
|
||||
# model_input = torch.cat([latents, mask, mask_latents], dim=1)
|
||||
return self.forward(model_input, t, text_embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.FloatTensor
|
||||
mask_latents: torch.FloatTensor
|
||||
_scheduler: SchedulerMixin
|
||||
_noise_func: Callable
|
||||
_debug: Optional[Callable] = None
|
||||
|
||||
def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
noise = self._noise_func(self.mask_latents)
|
||||
mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0]) # .to(dtype=mask_latents.dtype)
|
||||
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
# if self._debug:
|
||||
# self._debug(latents, f"t={t[0]} latents")
|
||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||
if self._debug:
|
||||
self._debug(masked_input, f"t={t[0]} lerped")
|
||||
return self.forward(masked_input, t, text_embeddings)
|
||||
|
||||
|
||||
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
@ -382,17 +405,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
|
||||
.to(device=device, dtype=latents_dtype)
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||
.to(device=device, dtype=latents_dtype)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
self.invokeai_diffuser.model_forward_callback = \
|
||||
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
|
||||
else:
|
||||
# FIXME: need to add guidance that applies mask
|
||||
pass
|
||||
self.invokeai_diffuser.model_forward_callback = \
|
||||
AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
|
||||
self.scheduler, noise_func) # self.debug_latents)
|
||||
|
||||
result = None
|
||||
|
||||
@ -417,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
with torch.inference_mode():
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample() # FIXME: uses torch.randn. make reproducible!
|
||||
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
noise = noise_func(init_latents)
|
||||
@ -456,3 +480,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.in_channels
|
||||
|
||||
def debug_latents(self, latents, msg):
|
||||
with torch.inference_mode():
|
||||
from ldm.util import debug_image
|
||||
decoded = self.numpy_to_pil(self.decode_latents(latents))
|
||||
for i, img in enumerate(decoded):
|
||||
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
|
||||
|
Loading…
Reference in New Issue
Block a user