Make float16 inference work with FLUX on 24GB GPU.

Ryan Dick 2024-08-08 18:12:04 -04:00 committed by Brandon
parent 5870742bb9
commit 3cf0365a35


@@ -147,6 +147,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
transformer=transformer,
)
t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
latents = flux_pipeline_with_transformer(
height=self.height,
width=self.width,
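
The key change is casting the T5 and CLIP embeddings to the transformer's dtype before denoising. Below is a minimal, self-contained sketch of why that cast matters for float16 inference; the tensor shapes and the `transformer_dtype` variable are illustrative assumptions, not taken from the actual pipeline code.

```python
import torch

# Minimal sketch (not the project's actual code): if the FLUX transformer is
# loaded in float16 but the text encoders emit float32 embeddings, operations
# inside the transformer would mix dtypes and either error out or upcast,
# increasing memory use. Casting the embeddings to the transformer's dtype
# keeps the whole forward pass in float16.

transformer_dtype = torch.float16  # dtype the transformer weights are loaded in

# Hypothetical float32 embeddings from the T5 and CLIP text encoders
# (shapes are illustrative only).
t5_embeddings = torch.randn(1, 512, 4096, dtype=torch.float32)
clip_embeddings = torch.randn(1, 768, dtype=torch.float32)

# The cast applied in this commit: match the embeddings to the transformer's
# dtype before they enter the denoising loop.
t5_embeddings = t5_embeddings.to(dtype=transformer_dtype)
clip_embeddings = clip_embeddings.to(dtype=transformer_dtype)

assert t5_embeddings.dtype == transformer_dtype
assert clip_embeddings.dtype == transformer_dtype
```

Keeping everything in float16 roughly halves the activation and embedding memory compared with float32, which is what allows FLUX inference to fit on a 24GB GPU.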