LLM.int8() quantization is working, but still some rough edges to solve.
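The new llm_int8 path relies on quantize_model_llm_int8(...) (imported in the first hunk below) to swap the transformer's nn.Linear modules for bitsandbytes int8 layers before any weights are loaded. That helper is not part of this diff, so the following is only an illustrative sketch of the kind of module swap it presumably performs, using bitsandbytes' Linear8bitLt; the function name and threshold value are assumptions, not the repo's actual implementation.

import bitsandbytes as bnb
import torch


def swap_linear_for_int8(module: torch.nn.Module, threshold: float = 6.0) -> None:
    """Recursively replace nn.Linear layers with bitsandbytes LLM.int8() layers.

    Weights are deliberately not copied here: in the diff below, the swapped model is
    built under init_empty_weights() and the pre-quantized state dict is attached
    afterwards with load_state_dict(..., assign=True).
    """
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            int8_linear = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,  # keep weights in int8 for inference
                threshold=threshold,  # outlier threshold for mixed-precision matmul
            )
            setattr(module, name, int8_linear)
        else:
            swap_linear_for_int8(child, threshold)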
@@ -21,6 +21,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
@@ -49,6 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""
 
     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
+    quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
+        default="raw", description="The type of quantization to use for the transformer model."
+    )
     use_8bit: bool = InputField(
         default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
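For reference, the new quantization_type field only accepts the three literal values above, and in the method below it is only consulted when use_8bit is True. A standalone sketch of the validation behaviour with plain pydantic (InvokeAI's InputField builds on a pydantic field; the UI metadata is omitted here as an assumption-free simplification):

from typing import Literal

from pydantic import BaseModel, ValidationError


class QuantizationSettings(BaseModel):
    # Mirrors the two fields added to FluxTextToImageInvocation above.
    quantization_type: Literal["raw", "NF4", "llm_int8"] = "raw"
    use_8bit: bool = False


print(QuantizationSettings(quantization_type="llm_int8", use_8bit=True))

try:
    QuantizationSettings(quantization_type="int4")  # not an allowed literal
except ValidationError as err:
    print(err)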
@@ -162,52 +166,37 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         return image
 
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-            assert isinstance(empty_model, FluxTransformer2DModel)
-
-            model_nf4_path = path / "bnb_nf4"
-            if model_nf4_path.exists():
+            if self.quantization_type == "raw":
+                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
+            elif self.quantization_type == "NF4":
+                model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+                with accelerate.init_empty_weights():
+                    empty_model = FluxTransformer2DModel.from_config(model_config)
+                assert isinstance(empty_model, FluxTransformer2DModel)
+                model_nf4_path = path / "bnb_nf4"
+                assert model_nf4_path.exists()
                 with accelerate.init_empty_weights():
                     model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
-
-                # model.to_empty(device="cpu")
                 # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
                 # this on GPUs without bfloat16 support.
                 sd = load_file(model_nf4_path / "model.safetensors")
                 model.load_state_dict(sd, strict=True, assign=True)
-                # model = model.to("cuda")
-
-            # model_8bit_path = path / "quantized"
-            # if model_8bit_path.exists():
-            #     # The quantized model exists, load it.
-            #     # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-            #     # something that we should be able to make much faster.
-            #     q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
-
-            #     # Access the underlying wrapped model.
-            #     # We access the wrapped model, even though it is private, because it simplifies the type checking by
-            #     # always returning a FluxTransformer2DModel from this function.
-            #     model = q_model._wrapped
-            # else:
-            #     # The quantized model does not exist yet, quantize and save it.
-            #     # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
-            #     # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
-            #     # here.
-            #     model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-            #     assert isinstance(model, FluxTransformer2DModel)
-
-            #     q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
-
-            #     model_8bit_path.mkdir(parents=True, exist_ok=True)
-            #     q_model.save_pretrained(model_8bit_path)
-
-            #     # (See earlier comment about accessing the wrapped model.)
-            #     model = q_model._wrapped
+            elif self.quantization_type == "llm_int8":
+                model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+                with accelerate.init_empty_weights():
+                    empty_model = FluxTransformer2DModel.from_config(model_config)
+                assert isinstance(empty_model, FluxTransformer2DModel)
+                model_int8_path = path / "bnb_llm_int8"
+                assert model_int8_path.exists()
+                with accelerate.init_empty_weights():
+                    model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+                sd = load_file(model_int8_path / "model.safetensors")
+                model.load_state_dict(sd, strict=True, assign=True)
+            else:
+                raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 
         assert isinstance(model, FluxTransformer2DModel)
         return model
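Both quantized branches (NF4 and llm_int8) above share the same loading pattern: build the model skeleton on the meta device, swap in quantized modules, then attach a pre-quantized safetensors checkpoint. A minimal standalone sketch of that pattern, assuming a pre-quantized checkpoint already exists on disk and PyTorch >= 2.1 (for load_state_dict(..., assign=True)); the function and directory names here are illustrative, not part of the diff.

from pathlib import Path
from typing import Callable

import accelerate
from diffusers import FluxTransformer2DModel
from safetensors.torch import load_file


def load_prequantized_transformer(
    model_dir: Path,
    quantized_dir: Path,
    quantize_fn: Callable[[FluxTransformer2DModel], FluxTransformer2DModel],
) -> FluxTransformer2DModel:
    # Build an uninitialized skeleton on the meta device: no memory is allocated for weights.
    config = FluxTransformer2DModel.load_config(model_dir, local_files_only=True)
    with accelerate.init_empty_weights():
        model = FluxTransformer2DModel.from_config(config)
        # Swap modules for their quantized counterparts while still on meta,
        # e.g. quantize_model_llm_int8(...) or quantize_model_nf4(...) in the diff above.
        model = quantize_fn(model)

    # assign=True rebinds the checkpoint tensors onto the module instead of copying them
    # into the existing (meta) parameters, so the saved quantized tensors are used directly.
    sd = load_file(quantized_dir / "model.safetensors")
    model.load_state_dict(sd, strict=True, assign=True)
    return model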