mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
from typing import Callable
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
|
from invokeai.backend.flux.model import Flux
|
|
|
|
|
|
def denoise(
    model: Flux,
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[], None],
    guidance: float,
    inpaint_extension: InpaintExtension | None,
):
    """Iteratively denoise ``img`` by stepping through consecutive ``timesteps``.

    For every adjacent pair ``(t_curr, t_prev)`` in ``timesteps``, the model is
    evaluated at ``t_curr``. ``img`` is then advanced with the explicit update
    ``img + (t_prev - t_curr) * model_output``. A tqdm progress bar tracks the
    steps.

    Args:
        model: Network called once per step with keyword arguments. Its output
            is scaled by the timestep difference and added to ``img``.
        img: Tensor being denoised. ``img.shape[0]`` sets the length of the
            per-sample timestep and guidance vectors (presumably the batch
            size; confirm against callers).
        img_ids: Forwarded to ``model`` unchanged.
        txt: Forwarded to ``model`` unchanged.
        txt_ids: Forwarded to ``model`` unchanged.
        vec: Forwarded to ``model`` as the ``y`` keyword argument.
        timesteps: Sampling schedule. With fewer than two entries no step is
            taken and ``img`` is returned as-is.
        step_callback: Called with no arguments after each completed step.
        guidance: Scalar guidance value broadcast into a vector. It is ignored
            for schnell models.
        inpaint_extension: When not ``None``, its
            ``merge_intermediate_latents_with_init_latents(img, t_prev)`` is
            applied to ``img`` after each step.

    Returns:
        The tensor produced by the final step, or the input ``img`` if no step
        ran.
    """
    # Built once from the incoming img; guidance_vec is ignored for schnell.
    guidance_per_sample = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # A list (unlike a bare zip object) has a len(), so tqdm can show a total.
    step_pairs = list(zip(timesteps[:-1], timesteps[1:], strict=True))
    for t_now, t_next in tqdm(step_pairs):
        # Batch size, dtype and device are re-read from the *current* img on
        # every step: the update and the inpaint merge below may replace it
        # (e.g. via dtype promotion).
        t_now_vec = torch.full((img.shape[0],), t_now, dtype=img.dtype, device=img.device)

        model_out = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_now_vec,
            guidance=guidance_per_sample,
        )

        # Out-of-place update: the caller's tensor is never mutated.
        step_size = t_next - t_now
        img = img + step_size * model_out

        if inpaint_extension is not None:
            img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_next)

        step_callback()

    return img
|