Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
56 lines · 1.5 KiB · Python
from typing import Callable

import torch
from tqdm import tqdm

from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
from invokeai.backend.flux.model import Flux
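

# `denoise` is a fixed-schedule Euler sampler for FLUX's rectified-flow
# (flow-matching) transformer: each iteration asks the model for a velocity
# prediction and advances the latents linearly from t_curr to t_prev. When an
# inpaint mask is given, the intermediate latents are merged with the init
# latents after every step so the unmasked regions stay pinned to the source
# image.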
def denoise(
    model: Flux,
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[], None],
    guidance: float,
    # For inpainting:
    init_latents: torch.Tensor | None,
    noise: torch.Tensor,
    inpaint_mask: torch.Tensor | None,
):
    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
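        # Both t_vec and guidance_vec broadcast a scalar to one value per
        # batch element (shape (B,)), which is the form the transformer
        # expects for its conditioning inputs.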
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
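        # This is one Euler step of the flow ODE: `pred` plays the role of a
        # velocity d(img)/dt, and with t decreasing toward 0 each step moves
        # the latents from noise toward the clean image, i.e. roughly
        #   img_{t_prev} ≈ img_{t_curr} + (t_prev - t_curr) * v(img_{t_curr}, t_curr)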

        if inpaint_mask is not None:
            assert init_latents is not None
            img = merge_intermediate_latents_with_init_latents(
                init_latents=init_latents,
                intermediate_latents=img,
                timestep=t_prev,
                noise=noise,
                inpaint_mask=inpaint_mask,
            )
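            # For intuition, a sketch of what a merge like this typically
            # computes (an assumption about the helper, not its verbatim
            # body): re-noise the init latents to t_prev so they match the
            # current noise level, then keep denoised content only where the
            # mask says to inpaint, e.g. with mask == 1 marking inpainted
            # regions:
            #   noised_init = t_prev * noise + (1.0 - t_prev) * init_latents
            #   img = inpaint_mask * img + (1.0 - inpaint_mask) * noised_init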

        step_callback()

    return img
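

# Hypothetical usage sketch. The schedule helper, tensor names, and values
# below are assumptions for illustration, not code from this file:
#
#   timesteps = get_schedule(num_steps=4, image_seq_len=img.shape[1])  # schnell-style schedule
#   out = denoise(
#       model=flux_model,
#       img=noise.clone(), img_ids=img_ids,
#       txt=txt, txt_ids=txt_ids, vec=clip_vec,
#       timesteps=timesteps,
#       step_callback=lambda: None,
#       guidance=3.5,  # ignored by schnell
#       init_latents=None, noise=noise, inpaint_mask=None,
#   )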