Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Tidy variable management and dtype handling in FluxTextToImageInvocation.
This commit is contained in:
parent 5e8cf9fb6a
commit 4e4b6c6dbc
@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16

+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
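
The hunk above centralizes conditioning handling: the embeddings are now loaded and cast to the inference dtype once, at the top of _run_diffusion, rather than being threaded through as float32 arguments. A minimal standalone sketch of this cast-once-at-the-boundary pattern, with invented tensor shapes standing in for the real conditioning data:

    import torch

    inference_dtype = torch.bfloat16

    # Stand-ins for freshly loaded conditioning tensors (float32 by default).
    clip_embeds = torch.randn(1, 768)
    t5_embeds = torch.randn(1, 512, 4096)

    # Cast once at the boundary; everything downstream assumes a uniform dtype.
    clip_embeds = clip_embeds.to(dtype=inference_dtype)
    t5_embeds = t5_embeds.to(dtype=inference_dtype)

    assert clip_embeds.dtype == t5_embeds.dtype == torch.bfloat16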
@@ -88,13 +90,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )

-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)

         is_schnell = "schnell" in transformer_info.config.config_path

         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )

@@ -135,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

             x = denoise(
                 model=transformer,
-                img=img,
+                img=x,
                 img_ids=img_ids,
                 txt=t5_embeddings,
                 txt_ids=txt_ids,
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
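
With the caller now guaranteeing consistent dtypes, the defensive .to(dtype=...) calls at the top of denoise() can be dropped, resolving the TODO. The sketch below (an invented layer and shapes, not InvokeAI code) illustrates what those casts were papering over: PyTorch does not silently promote across a float32/bfloat16 boundary in a linear layer, it raises.

    import torch

    layer = torch.nn.Linear(8, 8).to(dtype=torch.bfloat16)
    x = torch.randn(1, 8)  # float32 by default

    try:
        layer(x)  # mismatched input dtype
    except RuntimeError as err:
        print(f"rejected: {err}")

    out = layer(x.to(torch.bfloat16))  # matching dtype succeeds
    print(out.dtype)  # torch.bfloat16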
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     img = repeat(img, "1 ... -> bs ...", bs=bs)

     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

     return img, img_ids
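
The position-id grid in prepare_latent_img_patches now inherits the latent's dtype at construction, which is what made the casts removed from denoise() safe to drop. A self-contained sketch with made-up latent dimensions:

    import torch

    h, w = 64, 64  # invented latent height/width
    dtype, device = torch.bfloat16, "cpu"

    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] += torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] += torch.arange(w // 2, device=device, dtype=dtype)[None, :]

    # Equivalent of repeat(img_ids, "h w c -> b (h w) c") for batch size 1:
    img_ids = img_ids.reshape(1, -1, 3)
    print(img_ids.shape, img_ids.dtype)  # torch.Size([1, 1024, 3]) torch.bfloat16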
@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor

+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+

 @dataclass
 class ConditioningFieldData:
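
One caveat worth noting about the new method: unlike torch.Tensor.to(), which returns a new tensor, this to() mutates the instance in place and returns it, in the style of torch.nn.Module.to(). A self-contained usage sketch (class body copied from the hunk above, shapes invented):

    from dataclasses import dataclass

    import torch

    @dataclass
    class FLUXConditioningInfo:
        clip_embeds: torch.Tensor
        t5_embeds: torch.Tensor

        def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
            self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
            self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
            return self

    info = FLUXConditioningInfo(clip_embeds=torch.randn(1, 768), t5_embeds=torch.randn(1, 512, 4096))
    same = info.to(dtype=torch.bfloat16)
    assert same is info  # mutated in place, same object returned
    assert info.t5_embeds.dtype == torch.bfloat16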