From be6cb2c07ca7478b58530f6c526ffe3bb41667d5 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Mon, 19 Aug 2024 13:12:38 -0400
Subject: [PATCH] Working inference node with quantized bnb nf4 checkpoint

---
 .../app/invocations/flux_text_to_image.py    | 14 +++--
 .../model_manager/load/model_loaders/flux.py | 62 +++++++++++++++++--
 2 files changed, 65 insertions(+), 11 deletions(-)

diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 5e652b1375..fd7f53df10 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -89,7 +89,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
@@ -139,9 +139,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img = repeat(img, "1 ... -> bs ...", bs=bs)
 
         # Generate patch position ids.
-        img_ids = torch.zeros(h // 2, w // 2, 3)
-        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
         return img, img_ids
@@ -155,8 +155,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
             # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
+            vae.to(torch.float32)
+            latents.to(torch.float32)
+            img = vae.decode(latents)
 
         img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 78ecfccfa3..5ef7f460ce 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
+import accelerate
+import torch
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Optional
@@ -24,6 +26,7 @@ from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
     MainCheckpointConfig,
+    MainBnbQuantized4bCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
+from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 
 app_config = get_config()
 
@@ -62,7 +66,7 @@ class FluxVAELoader(GenericDiffusersLoader):
 
         with SilenceWarnings():
             model = load_class(params).to(self._torch_dtype)
             # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model
@@ -105,9 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer2:
-                return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
+                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
+                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
 
@@ -152,7 +156,55 @@ class FluxCheckpointModel(GenericDiffusersLoader):
 
         with SilenceWarnings():
             model = load_class(params).to(self._torch_dtype)
-            # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
+class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CheckpointConfigBase):
+            raise Exception("Only Checkpoint Flux models are currently supported.")
+        legacy_config_path = app_config.legacy_conf_path / config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            try:
+                flux_conf = yaml.safe_load(stream)
+            except:
+                raise
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config, flux_conf)
+
+        raise Exception("Only Checkpoint Flux models are currently supported.")
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+        flux_conf: Any,
+    ) -> AnyModel:
+        assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
+        load_class = Flux
+        params = None
+        model_path = Path(config.path)
+        dataclass_fields = {f.name for f in fields(FluxParams)}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
+        params = FluxParams(**filtered_data)
+
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = load_class(params)
+            model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
+            # this on GPUs without bfloat16 support.
+            sd = load_file(model_path)
+            model.load_state_dict(sd, strict=False, assign=True)
+        return model
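
Note (illustration, not part of the patch): the loading pattern introduced in _load_from_singlefile above — build the transformer on the meta device, swap in NF4-quantized layers, then assign the checkpoint tensors — can be sketched in isolation. This is a minimal, hedged version: the function name and its arguments are hypothetical, the model constructor is passed in rather than imported, and it assumes InvokeAI's quantize_model_nf4 helper plus the accelerate and safetensors packages are available as shown in the diff.

    # Minimal sketch of the NF4 single-file loading pattern used by the new loader above.
    # Assumptions: `quantize_model_nf4` is InvokeAI's helper referenced in the patch;
    # `build_transformer` stands in for constructing the Flux model from FluxParams.
    from pathlib import Path
    from typing import Callable

    import accelerate
    import torch
    from safetensors.torch import load_file

    from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4


    def load_nf4_checkpoint(model_path: Path, build_transformer: Callable[[], torch.nn.Module]) -> torch.nn.Module:
        # 1. Instantiate the model without allocating real weights; parameters land on the "meta" device.
        with accelerate.init_empty_weights():
            model = build_transformer()

        # 2. Replace eligible Linear layers with bitsandbytes NF4 equivalents while the model is still
        #    empty, so the pre-quantized state dict can be attached directly.
        model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

        # 3. Load the single-file checkpoint and assign its tensors; meta-device placeholders cannot be
        #    copied into, so assign=True (torch >= 2.1) is what makes this work.
        sd = load_file(model_path)
        model.load_state_dict(sd, strict=False, assign=True)
        return model

Quantizing before loading appears to be the point of the ordering here: the checkpoint already stores NF4-packed weights, so the layers must be converted while empty rather than materialized in bfloat16 and converted afterwards.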