From 57168d719b43d4980ee02e43a54aa4f2ec556029 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Tue, 20 Aug 2024 13:05:31 -0400 Subject: [PATCH] Fix styling/lint --- invokeai/app/invocations/flux_text_encoder.py | 3 +- invokeai/app/invocations/model.py | 14 +- .../model_install/model_install_default.py | 2 +- invokeai/backend/bnb.py | 517 ------------------ invokeai/backend/flux/modules/layers.py | 2 +- invokeai/backend/flux/sampling.py | 4 +- invokeai/backend/load_flux_model.py | 129 ----- .../load_flux_model_bnb_llm_int8_old.py | 124 ----- invokeai/backend/model_manager/config.py | 4 +- .../model_manager/load/model_loaders/flux.py | 4 +- invokeai/backend/model_manager/probe.py | 10 +- .../load_flux_model_bnb_llm_int8.py | 2 +- 12 files changed, 27 insertions(+), 788 deletions(-) delete mode 100644 invokeai/backend/bnb.py delete mode 100644 invokeai/backend/load_flux_model.py delete mode 100644 invokeai/backend/load_flux_model_bnb_llm_int8_old.py diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 7b3f074556..a57124d2bc 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -1,5 +1,6 @@ -import torch from typing import Literal + +import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 9c9d8eb834..e104dacde0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,8 +1,8 @@ import copy -import yaml from time import sleep from typing import Dict, List, Literal, Optional +import yaml from pydantic import BaseModel, Field from invokeai.app.invocations.baseinvocation import ( @@ -16,8 +16,14 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.services.model_records import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import CheckpointConfigBase +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + BaseModelType, + CheckpointConfigBase, + ModelFormat, + ModelType, + SubModelType, +) class ModelIdentifierField(BaseModel): @@ -207,7 +213,7 @@ class FluxModelLoaderInvocation(BaseInvocation): clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), vae=VAEField(vae=vae), - max_seq_len=flux_conf['max_seq_len'] + max_seq_len=flux_conf["max_seq_len"], ) def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 0369b86fb4..4ff4803438 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -784,7 +784,7 @@ class ModelInstallService(ModelInstallServiceBase): if subfolder: top = Path(remote_files[0].path.parts[0]) # e.g. 
"sdxl-turbo/" path_to_remove = top / subfolder # sdxl-turbo/vae/ - subfolder_rename = subfolder.name.replace('/', '_').replace('\\', '_') + subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_") path_to_add = Path(f"{top}_{subfolder_rename}") else: path_to_remove = Path(".") diff --git a/invokeai/backend/bnb.py b/invokeai/backend/bnb.py deleted file mode 100644 index 1022a1d1dc..0000000000 --- a/invokeai/backend/bnb.py +++ /dev/null @@ -1,517 +0,0 @@ -from typing import Any, Optional, Set, Type - -import bitsandbytes as bnb -import torch - -# The utils in this file take ideas from -# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py - - -# Patterns: -# - Quantize: -# - Initialize model on meta device -# - Replace layers -# - Load state_dict to cpu -# - Load state_dict into model -# - Quantize on GPU -# - Extract state_dict -# - Save - -# - Load: -# - Initialize model on meta device -# - Replace layers -# - Load state_dict to cpu -# - Load state_dict into model on cpu -# - Move to GPU - - -# class InvokeInt8Params(bnb.nn.Int8Params): -# """Overrides `bnb.nn.Int8Params` to add the following functionality: -# - Make it possible to load a quantized state dict without putting the weight on a "cuda" device. -# """ - -# def quantize(self, device: Optional[torch.device] = None): -# device = device or torch.device("cuda") -# if device.type != "cuda": -# raise RuntimeError(f"Int8Params quantization is only supported on CUDA devices ({device=}).") - -# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302 -# B = self.data.contiguous().half().cuda(device) -# if self.has_fp16_weights: -# self.data = B -# else: -# # we store the 8-bit rows-major weight -# # we convert this weight to the turning/ampere weight during the first inference pass -# CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) -# del CBt -# del SCBt -# self.data = CB -# self.CB = CB -# self.SCB = SCB - - -class Invoke2Linear8bitLt(torch.nn.Linear): - """This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.""" - - def __init__( - self, - input_features: int, - output_features: int, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - device=None, - ): - """ - Initialize Linear8bitLt class. - - Args: - input_features (`int`): - Number of input features of the linear layer. - output_features (`int`): - Number of output features of the linear layer. - bias (`bool`, defaults to `True`): - Whether the linear class uses the bias term as well. 
- """ - super().__init__(input_features, output_features, bias, device) - assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" - self.state = bnb.MatmulLtState() - self.index = index - - self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True - - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) - self._register_load_state_dict_pre_hook(maybe_rearrange_weight) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - super()._save_to_state_dict(destination, prefix, keep_vars) - - # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data - scb_name = "SCB" - - # case 1: .cuda was called, SCB is in self.weight - param_from_weight = getattr(self.weight, scb_name) - # case 2: self.init_8bit_state was called, SCB is in self.state - param_from_state = getattr(self.state, scb_name) - # case 3: SCB is in self.state, weight layout reordered after first forward() - layout_reordered = self.state.CxB is not None - - key_name = prefix + f"{scb_name}" - format_name = prefix + "weight_format" - - if not self.state.has_fp16_weights: - if param_from_weight is not None: - destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() - destination[format_name] = torch.tensor(0, dtype=torch.uint8) - elif param_from_state is not None and not layout_reordered: - destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - destination[format_name] = torch.tensor(0, dtype=torch.uint8) - elif param_from_state is not None: - destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - weights_format = self.state.formatB - # At this point `weights_format` is an str - if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: - raise ValueError(f"Unrecognized weights format {weights_format}") - - weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format] - - destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - unexpected_copy = list(unexpected_keys) - - for key in unexpected_copy: - input_name = key[len(prefix) :] - if input_name == "SCB": - if self.weight.SCB is None: - # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError( - "Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. 
Please call module.cuda() before module.load_state_dict()", - ) - - input_param = state_dict[key] - self.weight.SCB.copy_(input_param) - - if self.state.SCB is not None: - self.state.SCB = self.weight.SCB - - unexpected_keys.remove(key) - - def init_8bit_state(self): - self.state.CB = self.weight.CB - self.state.SCB = self.weight.SCB - self.weight.CB = None - self.weight.SCB = None - - def forward(self, x: torch.Tensor): - self.state.is_training = self.training - if self.weight.CB is not None: - self.init_8bit_state() - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - - if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB - return out - - -class InvokeLinear8bitLt(bnb.nn.Linear8bitLt): - """Wraps `bnb.nn.Linear8bitLt` and adds the following functionality: - - enables instantiation directly on the device - - re-quantizaton when loading the state dict - """ - - def __init__( - self, *args: Any, device: Optional[torch.device] = None, threshold: float = 6.0, **kwargs: Any - ) -> None: - super().__init__(*args, device=device, threshold=threshold, **kwargs) - # If the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up - # filling the device memory with float32 weights which could lead to OOM - # if torch.tensor(0, device=device).device.type == "cuda": - # self.quantize_() - # self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) - # self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - unexpected_copy = list(unexpected_keys) - - for key in unexpected_copy: - input_name = key[len(prefix) :] - if input_name == "SCB": - if self.weight.SCB is None: - # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError( - "Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. 
Please call module.cuda() before module.load_state_dict()", - ) - - input_param = state_dict[key] - self.weight.SCB.copy_(input_param) - - if self.state.SCB is not None: - self.state.SCB = self.weight.SCB - - unexpected_keys.remove(key) - - def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: - """Inplace quantize.""" - if weight is None: - weight = self.weight.data - if weight.data.dtype == torch.int8: - # already quantized - return - assert isinstance(self.weight, bnb.nn.Int8Params) - self.weight = self.quantize(self.weight, weight, device) - - @staticmethod - def quantize( - int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device] - ) -> bnb.nn.Int8Params: - device = device or torch.device("cuda") - if device.type != "cuda": - raise RuntimeError(f"Unexpected device type: {device.type}") - # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302 - B = weight.contiguous().to(device=device, dtype=torch.float16) - if int8params.has_fp16_weights: - int8params.data = B - else: - CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B) - del CBt - del SCBt - int8params.data = CB - int8params.CB = CB - int8params.SCB = SCB - return int8params - - -# class _Linear4bit(bnb.nn.Linear4bit): -# """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the -# state dict, meta-device initialization, and materialization.""" - -# def __init__(self, *args: Any, device: Optional[torch.device] = None, **kwargs: Any) -> None: -# super().__init__(*args, device=device, **kwargs) -# self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type] -# self.bias = cast(Optional[torch.nn.Parameter], self.bias) # type: ignore[has-type] -# # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up -# # filling the device memory with float32 weights which could lead to OOM -# if torch.tensor(0, device=device).device.type == "cuda": -# self.quantize_() -# self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) -# self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - -# def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: -# """Inplace quantize.""" -# if weight is None: -# weight = self.weight.data -# if weight.data.dtype == torch.uint8: -# # already quantized -# return -# assert isinstance(self.weight, bnb.nn.Params4bit) -# self.weight = self.quantize(self.weight, weight, device) - -# @staticmethod -# def quantize( -# params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device] -# ) -> bnb.nn.Params4bit: -# device = device or torch.device("cuda") -# if device.type != "cuda": -# raise RuntimeError(f"Unexpected device type: {device.type}") -# # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159 -# w = weight.contiguous().to(device=device, dtype=torch.half) -# w_4bit, quant_state = bnb.functional.quantize_4bit( -# w, -# blocksize=params4bit.blocksize, -# compress_statistics=params4bit.compress_statistics, -# quant_type=params4bit.quant_type, -# ) -# return _replace_param(params4bit, w_4bit, quant_state) - -# def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self: -# if self.weight.dtype == torch.uint8: # was quantized -# # cannot init the quantized params directly -# weight = torch.empty(self.weight.quant_state.shape, 
device=device, dtype=torch.half) -# else: -# weight = torch.empty_like(self.weight.data, device=device) -# device = torch.device(device) -# if device.type == "cuda": # re-quantize -# self.quantize_(weight, device) -# else: -# self.weight = _replace_param(self.weight, weight) -# if self.bias is not None: -# self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device)) -# return self - - -def convert_model_to_bnb_llm_int8(model: torch.nn.Module, ignore_modules: set[str]): - linear_cls = InvokeLinear8bitLt - _convert_linear_layers(model, linear_cls, ignore_modules) - - # TODO(ryand): Is this necessary? - # set the compute dtype if necessary - # for m in model.modules(): - # if isinstance(m, bnb.nn.Linear4bit): - # m.compute_dtype = self.dtype - # m.compute_type_is_set = False - - -# class BitsandbytesPrecision(Precision): -# """Plugin for quantizing weights with `bitsandbytes `__. - -# .. warning:: This is an :ref:`experimental ` feature. - -# .. note:: -# The optimizer is not automatically replaced with ``bitsandbytes.optim.Adam8bit`` or equivalent 8-bit optimizers. - -# Args: -# mode: The quantization mode to use. -# dtype: The compute dtype to use. -# ignore_modules: The submodules whose Linear layers should not be replaced, for example. ``{"lm_head"}``. -# This might be desirable for numerical stability. The string will be checked in as a prefix, so a value like -# "transformer.blocks" will ignore all linear layers in all of the transformer blocks. -# """ - -# def __init__( -# self, -# mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"], -# dtype: Optional[torch.dtype] = None, -# ignore_modules: Optional[Set[str]] = None, -# ) -> None: -# if dtype is None: -# # try to be smart about the default selection -# if mode.startswith("int8"): -# dtype = torch.float16 -# else: -# dtype = ( -# torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 -# ) -# if mode.startswith("int8") and dtype is not torch.float16: -# # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage -# raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`") - -# globals_ = globals() -# mode_to_cls = { -# "nf4": globals_["_NF4Linear"], -# "nf4-dq": globals_["_NF4DQLinear"], -# "fp4": globals_["_FP4Linear"], -# "fp4-dq": globals_["_FP4DQLinear"], -# "int8-training": globals_["_Linear8bitLt"], -# "int8": globals_["_Int8LinearInference"], -# } -# self._linear_cls = mode_to_cls[mode] -# self.dtype = dtype -# self.ignore_modules = ignore_modules or set() - -# @override -# def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: -# # avoid naive users thinking they quantized their model -# if not any(isinstance(m, torch.nn.Linear) for m in module.modules()): -# raise TypeError( -# "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin" -# " won't work for your model." 
-# ) - -# # convert modules if they haven't been converted already -# if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()): -# # this will not quantize the model but only replace the layer classes -# _convert_layers(module, self._linear_cls, self.ignore_modules) - -# # set the compute dtype if necessary -# for m in module.modules(): -# if isinstance(m, bnb.nn.Linear4bit): -# m.compute_dtype = self.dtype -# m.compute_type_is_set = False -# return module - - -# def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None: -# # There is only one key that ends with `*.weight`, the other one is the bias -# weight_key = next((name for name in state_dict if name.endswith("weight")), None) -# if weight_key is None: -# return -# # Load the weight from the state dict and re-quantize it -# weight = state_dict.pop(weight_key) -# quantize_fn(weight) - - -# def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None: -# # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false -# # positive -# for key in reversed(incompatible_keys.missing_keys): -# if key.endswith("weight"): -# incompatible_keys.missing_keys.remove(key) - - -def _convert_linear_layers( - module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = "" -) -> None: - for name, child in module.named_children(): - fullname = f"{prefix}.{name}" if prefix else name - if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): - has_bias = child.bias is not None - # since we are going to copy over the child's data, the device doesn't matter. I chose CPU - # to avoid spiking CUDA memory even though initialization is slower - # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit - _Linear4bit = globals()["_Linear4bit"] - device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu") - replacement = linear_cls( - child.in_features, - child.out_features, - bias=has_bias, - device=device, - ) - if has_bias: - replacement.bias = _replace_param(replacement.bias, child.bias.data.clone()) - state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None} - replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state) - module.__setattr__(name, replacement) - else: - _convert_linear_layers(child, linear_cls, ignore_modules, prefix=fullname) - - -# def _replace_linear_layers( -# model: torch.nn.Module, -# linear_layer_type: Literal["Linear8bitLt", "Linear4bit"], -# modules_to_not_convert: set[str], -# current_key_name: str | None = None, -# ): -# has_been_replaced = False -# for name, module in model.named_children(): -# if current_key_name is None: -# current_key_name = [] -# current_key_name.append(name) -# if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert: -# # Check if the current key is not in the `modules_to_not_convert` -# current_key_name_str = ".".join(current_key_name) -# proceed = True -# for key in modules_to_not_convert: -# if ( -# (key in current_key_name_str) and (key + "." 
in current_key_name_str) -# ) or key == current_key_name_str: -# proceed = False -# break -# if proceed: -# # Load bnb module with empty weight and replace ``nn.Linear` module -# if bnb_quantization_config.load_in_8bit: -# bnb_module = bnb.nn.Linear8bitLt( -# module.in_features, -# module.out_features, -# module.bias is not None, -# has_fp16_weights=False, -# threshold=bnb_quantization_config.llm_int8_threshold, -# ) -# elif bnb_quantization_config.load_in_4bit: -# bnb_module = bnb.nn.Linear4bit( -# module.in_features, -# module.out_features, -# module.bias is not None, -# bnb_quantization_config.bnb_4bit_compute_dtype, -# compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, -# quant_type=bnb_quantization_config.bnb_4bit_quant_type, -# ) -# else: -# raise ValueError("load_in_8bit and load_in_4bit can't be both False") -# bnb_module.weight.data = module.weight.data -# if module.bias is not None: -# bnb_module.bias.data = module.bias.data -# bnb_module.requires_grad_(False) -# setattr(model, name, bnb_module) -# has_been_replaced = True -# if len(list(module.children())) > 0: -# _, _has_been_replaced = _replace_with_bnb_layers( -# module, bnb_quantization_config, modules_to_not_convert, current_key_name -# ) -# has_been_replaced = has_been_replaced | _has_been_replaced -# # Remove the last key for recursion -# current_key_name.pop(-1) -# return model, has_been_replaced diff --git a/invokeai/backend/flux/modules/layers.py b/invokeai/backend/flux/modules/layers.py index 4f9d515daf..d93dddba0f 100644 --- a/invokeai/backend/flux/modules/layers.py +++ b/invokeai/backend/flux/modules/layers.py @@ -5,7 +5,7 @@ import torch from einops import rearrange from torch import Tensor, nn -from ..math import attention, rope +from invokeai.backend.flux.math import attention, rope class EmbedND(nn.Module): diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py index 5d670c3e69..675728a94b 100644 --- a/invokeai/backend/flux/sampling.py +++ b/invokeai/backend/flux/sampling.py @@ -6,8 +6,8 @@ from einops import rearrange, repeat from torch import Tensor from tqdm import tqdm -from .model import Flux -from .modules.conditioner import HFEncoder +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.modules.conditioner import HFEncoder def get_noise( diff --git a/invokeai/backend/load_flux_model.py b/invokeai/backend/load_flux_model.py deleted file mode 100644 index 9273122396..0000000000 --- a/invokeai/backend/load_flux_model.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -import os -import time -from pathlib import Path -from typing import Union - -import torch -from diffusers.models.model_loading_utils import load_state_dict -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from diffusers.utils import ( - CONFIG_NAME, - SAFE_WEIGHTS_INDEX_NAME, - SAFETENSORS_WEIGHTS_NAME, - _get_checkpoint_shard_files, - is_accelerate_available, -) -from optimum.quanto import qfloat8 -from optimum.quanto.models import QuantizedDiffusersModel -from optimum.quanto.models.shared_dict import ShardedStateDict - -from invokeai.backend.requantize import requantize - - -class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel): - base_class = FluxTransformer2DModel - - @classmethod - def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): - if cls.base_class is None: - raise ValueError("The `base_class` attribute needs to be configured.") - - if not is_accelerate_available(): - raise ValueError("Reloading a quantized 
diffusers model requires the accelerate library.") - from accelerate import init_empty_weights - - if os.path.isdir(model_name_or_path): - # Look for a quantization map - qmap_path = os.path.join(model_name_or_path, cls._qmap_name()) - if not os.path.exists(qmap_path): - raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?") - - # Look for original model config file. - model_config_path = os.path.join(model_name_or_path, CONFIG_NAME) - if not os.path.exists(model_config_path): - raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.") - - with open(qmap_path, "r", encoding="utf-8") as f: - qmap = json.load(f) - - with open(model_config_path, "r", encoding="utf-8") as f: - original_model_cls_name = json.load(f)["_class_name"] - configured_cls_name = cls.base_class.__name__ - if configured_cls_name != original_model_cls_name: - raise ValueError( - f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})." - ) - - # Create an empty model - config = cls.base_class.load_config(model_name_or_path) - with init_empty_weights(): - model = cls.base_class.from_config(config) - - # Look for the index of a sharded checkpoint - checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - if os.path.exists(checkpoint_file): - # Convert the checkpoint path to a list of shards - _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file) - # Create a mapping for the sharded safetensor files - state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"]) - else: - # Look for a single checkpoint file - checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME) - if not os.path.exists(checkpoint_file): - raise ValueError(f"No safetensor weights found in {model_name_or_path}.") - # Get state_dict from model checkpoint - state_dict = load_state_dict(checkpoint_file) - - # Requantize and load quantized weights from state_dict - requantize(model, state_dict=state_dict, quantization_map=qmap) - model.eval() - return cls(model) - else: - raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.") - - -def load_flux_transformer(path: Path) -> FluxTransformer2DModel: - # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) - model_8bit_path = path / "quantized" - if model_8bit_path.exists(): - # The quantized model exists, load it. - # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like - # something that we should be able to make much faster. - q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path) - - # Access the underlying wrapped model. - # We access the wrapped model, even though it is private, because it simplifies the type checking by - # always returning a FluxTransformer2DModel from this function. - model = q_model._wrapped - else: - # The quantized model does not exist yet, quantize and save it. - # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on - # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it - # here. 
- model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) - assert isinstance(model, FluxTransformer2DModel) - - q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8) - - model_8bit_path.mkdir(parents=True, exist_ok=True) - q_model.save_pretrained(model_8bit_path) - - # (See earlier comment about accessing the wrapped model.) - model = q_model._wrapped - - assert isinstance(model, FluxTransformer2DModel) - return model - - -def main(): - start = time.time() - model = load_flux_transformer( - Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/") - ) - print(f"Time to load: {time.time() - start}s") - print("hi") - - -if __name__ == "__main__": - main() diff --git a/invokeai/backend/load_flux_model_bnb_llm_int8_old.py b/invokeai/backend/load_flux_model_bnb_llm_int8_old.py deleted file mode 100644 index f7e1471928..0000000000 --- a/invokeai/backend/load_flux_model_bnb_llm_int8_old.py +++ /dev/null @@ -1,124 +0,0 @@ -import time -from pathlib import Path - -import accelerate -import torch -from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model -from accelerate.utils.bnb import get_keys_to_not_convert -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from safetensors.torch import load_file - -from invokeai.backend.bnb import quantize_model_llm_int8 - -# Docs: -# https://huggingface.co/docs/accelerate/usage_guides/quantization -# https://huggingface.co/docs/bitsandbytes/v0.43.3/en/integrations#accelerate - - -def get_parameter_device(parameter: torch.nn.Module): - return next(parameter.parameters()).device - - -# def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], llm_int8_threshold: int = 6): -# """Apply bitsandbytes LLM.8bit() quantization to the model.""" -# model_device = get_parameter_device(model) -# if model_device.type != "meta": -# # Note: This is not strictly required, but I can't think of a good reason to quantize a model that's not on the -# # meta device, so we enforce it for now. -# raise RuntimeError("The model should be on the meta device to apply LLM.8bit() quantization.") - -# bnb_quantization_config = BnbQuantizationConfig( -# load_in_8bit=True, -# llm_int8_threshold=llm_int8_threshold, -# ) - -# with accelerate.init_empty_weights(): -# model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) - -# return model - - -def load_flux_transformer(path: Path) -> FluxTransformer2DModel: - model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) - with accelerate.init_empty_weights(): - empty_model = FluxTransformer2DModel.from_config(model_config) - assert isinstance(empty_model, FluxTransformer2DModel) - - bnb_quantization_config = BnbQuantizationConfig( - load_in_8bit=True, - llm_int8_threshold=6, - ) - - model_8bit_path = path / "bnb_llm_int8" - if model_8bit_path.exists(): - # The quantized model already exists, load it and return it. - # Note that the model loading code is the same when loading from quantized vs original weights. The only - # difference is the weights_location. - # model = load_and_quantize_model( - # empty_model, - # weights_location=model_8bit_path, - # bnb_quantization_config=bnb_quantization_config, - # # device_map="auto", - # device_map={"": "cpu"}, - # ) - - # TODO: Handle the keys that were not quantized (get_keys_to_not_convert). 
- model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set()) - - # model = quantize_model_llm_int8(empty_model, set()) - - # Load sharded state dict. - files = list(path.glob("*.safetensors")) - state_dict = dict() - for file in files: - sd = load_file(file) - state_dict.update(sd) - - else: - # The quantized model does not exist yet, quantize and save it. - model = load_and_quantize_model( - empty_model, - weights_location=path, - bnb_quantization_config=bnb_quantization_config, - device_map="auto", - ) - - keys_to_not_convert = get_keys_to_not_convert(empty_model) # TODO - - model_8bit_path.mkdir(parents=True, exist_ok=True) - accl = accelerate.Accelerator() - accl.save_model(model, model_8bit_path) - - # --------------------- - - # model = quantize_model_llm_int8(empty_model, set()) - - # # Load sharded state dict. - # files = list(path.glob("*.safetensors")) - # state_dict = dict() - # for file in files: - # sd = load_file(file) - # state_dict.update(sd) - - # # Load the state dict into the model. The bitsandbytes layers know how to load from both quantized and - # # non-quantized state dicts. - # result = model.load_state_dict(state_dict, strict=True) - # model = model.to("cuda") - - # --------------------- - - assert isinstance(model, FluxTransformer2DModel) - return model - - -def main(): - start = time.time() - model = load_flux_transformer( - Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/") - ) - print(f"Time to load: {time.time() - start}s") - print("hi") - - -if __name__ == "__main__": - main() diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 5dd74dbacc..34cc993d39 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -194,7 +194,9 @@ class ModelConfigBase(BaseModel): class CheckpointConfigBase(ModelConfigBase): """Model config for checkpoint-style models.""" - format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint) + format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field( + description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint + ) config_path: str = Field(description="path to the checkpoint model config file") converted_at: Optional[float] = Field( description="When this model was last converted to diffusers", default_factory=time.time diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 5872936965..6502339a24 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -27,15 +27,15 @@ from invokeai.backend.model_manager.config import ( CLIPEmbedDiffusersConfig, MainBnbQuantized4bCheckpointConfig, MainCheckpointConfig, - T5EncoderConfig, T5Encoder8bConfig, + T5EncoderConfig, VAECheckpointConfig, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 -from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel +from invokeai.backend.util.silence_warnings import 
SilenceWarnings app_config = get_config() diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 6ce090d651..a3364da769 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -177,10 +177,10 @@ class ModelProbe(object): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models - if ( - fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] - and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] - ): + if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ + ModelFormat.Checkpoint, + ModelFormat.BnbQuantizednf4b, + ]: ckpt_config_path = cls._get_checkpoint_config_path( model_path, model_type=fields["type"], @@ -326,7 +326,7 @@ class ModelProbe(object): # TODO: Decide between dev/schnell checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) state_dict = checkpoint.get("state_dict") or checkpoint - if 'guidance_in.out_layer.weight' in state_dict: + if "guidance_in.out_layer.weight" in state_dict: config_file = "flux/flux1-dev.yaml" else: config_file = "flux/flux1-schnell.yaml" diff --git a/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py index fd54210cbe..876f299add 100644 --- a/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py +++ b/invokeai/backend/quantization/load_flux_model_bnb_llm_int8.py @@ -64,7 +64,7 @@ def main(): with log_time("Load state dict into model"): # Load sharded state dict. files = list(model_path.glob("*.safetensors")) - state_dict = dict() + state_dict = {} for file in files: sd = load_file(file) state_dict.update(sd)
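
Note on the final hunk: it swaps dict() for a {} literal at the point where a sharded
safetensors checkpoint is merged into a single flat state dict. A minimal standalone sketch
of that loading pattern, assuming only a local directory of *.safetensors shards (the
model_path value below is hypothetical; load_file comes from the safetensors package):

    # Sketch: merge a sharded safetensors checkpoint into one flat state dict.
    from pathlib import Path

    from safetensors.torch import load_file

    model_path = Path("/path/to/FLUX.1-schnell/transformer")  # hypothetical location

    files = sorted(model_path.glob("*.safetensors"))
    state_dict = {}
    for file in files:
        # Each shard maps parameter names to tensors; later shards extend the dict.
        state_dict.update(load_file(file))

    print(f"Loaded {len(state_dict)} tensors from {len(files)} shard(s)")

The deleted scripts above consumed a dict built this way, either by loading it into a
bitsandbytes-converted model or by handing it to optimum-quanto's requantize().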