diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index b6f2d6dedd..e08b4f38fd 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -17,9 +17,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 @invocation(
@@ -90,7 +90,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
+        is_schnell = (
+            "schnell" in transformer_info.config.config_path
+            if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase)
+            else ""
+        )
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
@@ -161,7 +165,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             latents.to(torch.float32)
             img = vae.decode(latents)
 
-        img.clamp(-1, 1)
+        img = img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index 89d9d417e0..3e3c933d4e 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -104,9 +104,18 @@ def denoise(
     timesteps: list[float],
     guidance: float = 4.0,
 ):
+    dtype = model.txt_in.bias.dtype
+
+    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
+    img = img.to(dtype=dtype)
+    img_ids = img_ids.to(dtype=dtype)
+    txt = txt.to(dtype=dtype)
+    txt_ids = txt_ids.to(dtype=dtype)
+    vec = vec.to(dtype=dtype)
+
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
-    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
+    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
             img=img,
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 11a6ebcf6d..3ba933bf48 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,12 +1,12 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import accelerate
-import torch
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Optional
 
+import accelerate
+import torch
 import yaml
 from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -25,15 +25,15 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
-    MainCheckpointConfig,
     MainBnbQuantized4bCheckpointConfig,
+    MainCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 
 app_config = get_config()
@@ -109,7 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Tokenizer2:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2"
+                )  # TODO: Fix hf subfolder install
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
 
@@ -153,7 +155,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
         params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
-            model = load_class(params).to(self._torch_dtype)
+            model = load_class(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model