mirror of https://github.com/invoke-ai/InvokeAI
inpainting for the normal model [WIP]
This seems to be performing well until the LAST STEP, at which point it dissolves to confetti.
parent b2664e807e
commit f3570d8344
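For context on the approach: for UNets without inpainting channels, this commit guides each denoising step by re-noising the original image's latents to the current timestep and blending them with the working latents under the mask, so only the masked region evolves freely. Below is a minimal standalone sketch of that technique, not the committed code; `scheduler`, `noise_func`, `unet_forward`, and the tensor shapes are assumptions. One plausible culprit for the last-step confetti: `add_noise` at the final timestep injects noise that no subsequent step removes.

import torch

def masked_guidance_step(latents, t, mask, orig_latents,
                         scheduler, noise_func, unet_forward, text_embeddings):
    # Hypothetical helper mirroring the AddsMaskGuidance class in the diff below.
    # mask == 1 marks the region to repaint; orig_latents are the VAE latents
    # of the original image.
    noise = noise_func(orig_latents)
    # Re-noise the original latents to the current timestep so they are
    # statistically comparable with the partially denoised working latents.
    noised_orig = scheduler.add_noise(orig_latents, noise, t)
    # Outside the mask, snap back to the (noised) original; inside the mask,
    # keep the sampler's own latents.
    blended = torch.lerp(noised_orig.to(latents.dtype), latents, mask.to(latents.dtype))
    return unet_forward(blended, t, text_embeddings)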
@@ -22,6 +22,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -61,10 +62,32 @@ class AddsMaskLatents:
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
-        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
         return self.forward(model_input, t, text_embeddings)
+
+
+@dataclass
+class AddsMaskGuidance:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+    _scheduler: SchedulerMixin
+    _noise_func: Callable
+    _debug: Optional[Callable] = None
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        noise = self._noise_func(self.mask_latents)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # if self._debug:
+        #     self._debug(latents, f"t={t[0]} latents")
+        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        if self._debug:
+            self._debug(masked_input, f"t={t[0]} lerped")
+        return self.forward(masked_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
 
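A sketch of how the new `AddsMaskGuidance` is meant to slot in per step; the scheduler choice, the `unet_forward` stub, and the 64x64 latent shapes are assumptions for illustration, not part of the commit:

import torch
from diffusers import DDIMScheduler

unet_forward = lambda x, t, emb: x         # stand-in for the pipeline's _unet_forward
mask = torch.ones(1, 1, 64, 64)            # 1 = repaint, 0 = keep the original
mask_latents = torch.randn(1, 4, 64, 64)   # VAE latents of the original image

guided = AddsMaskGuidance(
    forward=unet_forward,
    mask=mask,
    mask_latents=mask_latents,
    _scheduler=DDIMScheduler(),
    _noise_func=torch.randn_like,
)
# At each step the callback re-noises mask_latents to timestep t, lerps them
# with the working latents (mask picks the repaint region), and only then
# calls the UNet:
# noise_pred = guided(latents, t, text_embeddings)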
@@ -382,17 +405,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
 
-        if is_inpainting_model(self.unet):
-            if mask.dim() == 3:
-                mask = mask.unsqueeze(0)
-            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
-                .to(device=device, dtype=latents_dtype)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(0)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+            .to(device=device, dtype=latents_dtype)
 
+        if is_inpainting_model(self.unet):
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            # FIXME: need to add guidance that applies mask
-            pass
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
+                                 self.scheduler, noise_func)  # self.debug_latents)
 
         result = None
 
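`is_inpainting_model` is not shown in this diff. A plausible implementation, an assumption rather than the commit's code, just checks for the 9 input channels (4 latent + 1 mask + 4 masked-image latents) that inpainting UNets expect:

def is_inpainting_model(unet) -> bool:
    # Assumed check: inpainting UNets concatenate 4 latent channels,
    # 1 mask channel, and 4 masked-image latent channels.
    return unet.conv_in.in_channels == 9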
@@ -417,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
             init_latents = 0.18215 * init_latents
 
         noise = noise_func(init_latents)
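On the FIXME above: `latent_dist.sample()` calls `torch.randn` internally, so VAE encoding is not reproducible across runs. One way to address it, assuming diffusers' `DiagonalGaussianDistribution` with its `mean` and `std` attributes, is to reparameterize with an explicitly seeded generator; a hypothetical helper:

import torch

def sample_latents_reproducibly(latent_dist, seed: int, device="cuda"):
    # Reparameterization: mean + std * eps, with eps drawn from a seeded
    # generator instead of the distribution's internal torch.randn call.
    generator = torch.Generator(device=device).manual_seed(seed)
    eps = torch.randn(latent_dist.mean.shape, generator=generator,
                      device=device, dtype=latent_dist.mean.dtype)
    return latent_dist.mean + latent_dist.std * eps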
@@ -456,3 +480,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
+
+    def debug_latents(self, latents, msg):
+        with torch.inference_mode():
+            from ldm.util import debug_image
+            decoded = self.numpy_to_pil(self.decode_latents(latents))
+            for i, img in enumerate(decoded):
+                debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
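The commented-out `# self.debug_latents` in the earlier hunk hints at how this helper is meant to be used: pass it as `AddsMaskGuidance`'s `_debug` callback to dump the blended latents at every timestep. A sketch of that wiring (an assumption, not in this commit):

self.invokeai_diffuser.model_forward_callback = \
    AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
                     self.scheduler, noise_func, _debug=self.debug_latents)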