Fix FLUX output image clamping, and make a few other minor fixes so that inference works with the full bfloat16 FLUX transformer model.

This commit is contained in:
Ryan Dick
2024-08-20 14:39:33 +00:00
committed by Brandon
parent a63f842a13
commit 0c5e11f521
3 changed files with 25 additions and 10 deletions

View File

@ -104,9 +104,18 @@ def denoise(
timesteps: list[float],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,

View File

@ -1,12 +1,12 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
import accelerate
import torch
from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
import accelerate
import torch
import yaml
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@ -25,15 +25,15 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.util.silence_warnings import SilenceWarnings
app_config = get_config()
@ -109,7 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install
raise Exception("Only Checkpoint Flux models are currently supported.")
@ -153,7 +155,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
params = FluxParams(**filtered_data)
with SilenceWarnings():
model = load_class(params).to(self._torch_dtype)
model = load_class(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model