Fix FLUX output image clamping. And a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model.

This commit is contained in:
Ryan Dick 2024-08-20 14:39:33 +00:00 committed by Brandon
parent a63f842a13
commit 0c5e11f521
3 changed files with 25 additions and 10 deletions

View File

@ -17,9 +17,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.model_manager.config import CheckpointConfigBase
@invocation(
@ -90,7 +90,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
is_schnell = (
"schnell" in transformer_info.config.config_path
if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase)
else ""
)
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
@ -161,7 +165,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents.to(torch.float32)
img = vae.decode(latents)
img.clamp(-1, 1)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

View File

@ -104,9 +104,18 @@ def denoise(
timesteps: list[float],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,

View File

@ -1,12 +1,12 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
import accelerate
import torch
from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
import accelerate
import torch
import yaml
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@ -25,15 +25,15 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.util.silence_warnings import SilenceWarnings
app_config = get_config()
@ -109,7 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install
raise Exception("Only Checkpoint Flux models are currently supported.")
@ -153,7 +155,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
params = FluxParams(**filtered_data)
with SilenceWarnings():
model = load_class(params).to(self._torch_dtype)
model = load_class(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model