mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Manage quantization of models within the loader
This commit is contained in:
parent
1d8545a76c
commit
56fda669fd
@ -126,6 +126,7 @@ class FieldDescriptions:
|
|||||||
negative_cond = "Negative conditioning tensor"
|
negative_cond = "Negative conditioning tensor"
|
||||||
noise = "Noise tensor"
|
noise = "Noise tensor"
|
||||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
t5Encoder = "T5 tokenizer and text encoder"
|
||||||
unet = "UNet (scheduler, LoRAs)"
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
transformer = "Transformer"
|
transformer = "Transformer"
|
||||||
vae = "VAE"
|
vae = "VAE"
|
||||||
|
@ -6,8 +6,10 @@ from optimum.quanto import qfloat8
|
|||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
from invokeai.app.invocations.fields import InputField
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
|
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
|
||||||
|
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
|
||||||
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||||
@ -22,9 +24,15 @@ from invokeai.backend.util.devices import TorchDevice
|
|||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class FluxTextEncoderInvocation(BaseInvocation):
|
class FluxTextEncoderInvocation(BaseInvocation):
|
||||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
clip: CLIPField = InputField(
|
||||||
use_8bit: bool = InputField(
|
title="CLIP",
|
||||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
t5Encoder: T5EncoderField = InputField(
|
||||||
|
title="T5EncoderField",
|
||||||
|
description=FieldDescriptions.t5Encoder,
|
||||||
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
||||||
|
|
||||||
@ -32,47 +40,43 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
# compatible with other ConditioningOutputs.
|
# compatible with other ConditioningOutputs.
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
|
||||||
|
|
||||||
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
|
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput.build(conditioning_name)
|
||||||
|
|
||||||
|
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# TODO: Determine the T5 max sequence length based on the model.
|
||||||
|
# if self.model == "flux-schnell":
|
||||||
|
max_seq_len = 256
|
||||||
|
# # elif self.model == "flux-dev":
|
||||||
|
# # max_seq_len = 512
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Unknown model: {self.model}")
|
||||||
|
|
||||||
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
|
# Load CLIP.
|
||||||
# Determine the T5 max sequence length based on the model.
|
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||||
if self.model == "flux-schnell":
|
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||||
max_seq_len = 256
|
|
||||||
# elif self.model == "flux-dev":
|
|
||||||
# max_seq_len = 512
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown model: {self.model}")
|
|
||||||
|
|
||||||
# Load the CLIP tokenizer.
|
# Load T5.
|
||||||
clip_tokenizer_path = flux_model_dir / "tokenizer"
|
t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
|
||||||
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
|
t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
|
||||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
|
||||||
|
|
||||||
# Load the T5 tokenizer.
|
|
||||||
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
|
|
||||||
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
|
||||||
assert isinstance(t5_tokenizer, T5TokenizerFast)
|
|
||||||
|
|
||||||
clip_text_encoder_path = flux_model_dir / "text_encoder"
|
|
||||||
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
|
|
||||||
with (
|
with (
|
||||||
context.models.load_local_model(
|
clip_text_encoder_info as clip_text_encoder,
|
||||||
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
|
t5_text_encoder_info as t5_text_encoder,
|
||||||
) as clip_text_encoder,
|
clip_tokenizer_info as clip_tokenizer,
|
||||||
context.models.load_local_model(
|
t5_tokenizer_info as t5_tokenizer,
|
||||||
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
|
|
||||||
) as t5_text_encoder,
|
|
||||||
):
|
):
|
||||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||||
|
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||||
|
assert isinstance(t5_tokenizer, T5TokenizerFast)
|
||||||
|
|
||||||
pipeline = FluxPipeline(
|
pipeline = FluxPipeline(
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
vae=None,
|
vae=None,
|
||||||
@ -85,7 +89,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# prompt_embeds: T5 embeddings
|
# prompt_embeds: T5 embeddings
|
||||||
# pooled_prompt_embeds: CLIP embeddings
|
# pooled_prompt_embeds: CLIP embeddings
|
||||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
|
||||||
prompt=self.positive_prompt,
|
prompt=self.positive_prompt,
|
||||||
prompt_2=self.positive_prompt,
|
prompt_2=self.positive_prompt,
|
||||||
device=TorchDevice.choose_torch_device(),
|
device=TorchDevice.choose_torch_device(),
|
||||||
@ -95,41 +99,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
assert isinstance(prompt_embeds, torch.Tensor)
|
assert isinstance(prompt_embeds, torch.Tensor)
|
||||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||||
return prompt_embeds, pooled_prompt_embeds
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
|
|
||||||
model = CLIPTextModel.from_pretrained(path, local_files_only=True)
|
|
||||||
assert isinstance(model, CLIPTextModel)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
|
|
||||||
if self.use_8bit:
|
|
||||||
model_8bit_path = path / "quantized"
|
|
||||||
if model_8bit_path.exists():
|
|
||||||
# The quantized model exists, load it.
|
|
||||||
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
|
||||||
# something that we should be able to make much faster.
|
|
||||||
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
|
|
||||||
|
|
||||||
# Access the underlying wrapped model.
|
|
||||||
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
|
||||||
# always returning a T5EncoderModel from this function.
|
|
||||||
model = q_model._wrapped
|
|
||||||
else:
|
|
||||||
# The quantized model does not exist yet, quantize and save it.
|
|
||||||
# TODO(ryand): dtype?
|
|
||||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
|
||||||
assert isinstance(model, T5EncoderModel)
|
|
||||||
|
|
||||||
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
|
|
||||||
|
|
||||||
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
q_model.save_pretrained(model_8bit_path)
|
|
||||||
|
|
||||||
# (See earlier comment about accessing the wrapped model.)
|
|
||||||
model = q_model._wrapped
|
|
||||||
else:
|
|
||||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
|
||||||
|
|
||||||
assert isinstance(model, T5EncoderModel)
|
|
||||||
return model
|
|
||||||
|
@ -6,7 +6,7 @@ import accelerate
|
|||||||
import torch
|
import torch
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||||
from optimum.quanto import qfloat8
|
from optimum.quanto import qfloat8
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
|||||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Text-to-image generation using a FLUX model."""
|
"""Text-to-image generation using a FLUX model."""
|
||||||
|
|
||||||
flux_model: ModelIdentifierField = InputField(
|
transformer: TransformerField = InputField(
|
||||||
description="The Flux model",
|
description=FieldDescriptions.unet,
|
||||||
input=Input.Any,
|
input=Input.Connection,
|
||||||
ui_type=UIType.FluxMainModel
|
title="Transformer",
|
||||||
)
|
)
|
||||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
vae: VAEField = InputField(
|
||||||
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
description=FieldDescriptions.vae,
|
||||||
default="raw", description="The type of quantization to use for the transformer model."
|
input=Input.Connection,
|
||||||
)
|
|
||||||
use_8bit: bool = InputField(
|
|
||||||
default=False, description="Whether to quantize the transformer model to 8-bit precision."
|
|
||||||
)
|
)
|
||||||
positive_text_conditioning: ConditioningField = InputField(
|
positive_text_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
@ -78,13 +75,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
|
||||||
flux_transformer_path = context.models.download_and_cache_model(
|
|
||||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
|
|
||||||
)
|
|
||||||
flux_ae_path = context.models.download_and_cache_model(
|
|
||||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the conditioning data.
|
# Load the conditioning data.
|
||||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||||
@ -92,56 +82,31 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
flux_conditioning = cond_data.conditionings[0]
|
flux_conditioning = cond_data.conditionings[0]
|
||||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||||
|
|
||||||
latents = self._run_diffusion(
|
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||||
context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
|
image = self._run_vae_decoding(context, latents)
|
||||||
)
|
|
||||||
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
|
||||||
image_dto = context.images.save(image=image)
|
image_dto = context.images.save(image=image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
def _run_diffusion(
|
def _run_diffusion(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
flux_transformer_path: Path,
|
|
||||||
clip_embeddings: torch.Tensor,
|
clip_embeddings: torch.Tensor,
|
||||||
t5_embeddings: torch.Tensor,
|
t5_embeddings: torch.Tensor,
|
||||||
):
|
):
|
||||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
scheduler_info = context.models.load(self.transformer.scheduler)
|
||||||
|
transformer_info = context.models.load(self.transformer.transformer)
|
||||||
# Prepare input noise.
|
|
||||||
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
|
||||||
# CPU RNG?
|
|
||||||
x = get_noise(
|
|
||||||
num_samples=1,
|
|
||||||
height=self.height,
|
|
||||||
width=self.width,
|
|
||||||
device=TorchDevice.choose_torch_device(),
|
|
||||||
dtype=inference_dtype,
|
|
||||||
seed=self.seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
img, img_ids = self._prepare_latent_img_patches(x)
|
|
||||||
|
|
||||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
|
||||||
is_schnell = "shnell" in str(flux_transformer_path)
|
|
||||||
timesteps = get_schedule(
|
|
||||||
num_steps=self.num_steps,
|
|
||||||
image_seq_len=img.shape[1],
|
|
||||||
shift=not is_schnell,
|
|
||||||
)
|
|
||||||
|
|
||||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
|
||||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
|
||||||
|
|
||||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||||
# if the cache is not empty.
|
# if the cache is not empty.
|
||||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||||
|
|
||||||
with context.models.load_local_model(
|
with (
|
||||||
model_path=flux_transformer_path, loader=self._load_flux_transformer
|
transformer_info as transformer,
|
||||||
) as transformer:
|
scheduler_info as scheduler
|
||||||
assert isinstance(transformer, Flux)
|
):
|
||||||
|
assert isinstance(transformer, FluxTransformer2DModel)
|
||||||
|
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
|
||||||
|
|
||||||
x = denoise(
|
x = denoise(
|
||||||
model=transformer,
|
model=transformer,
|
||||||
@ -185,75 +150,25 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
def _run_vae_decoding(
|
def _run_vae_decoding(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
flux_ae_path: Path,
|
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
|
vae_info = context.models.load(self.vae.vae)
|
||||||
assert isinstance(vae, AutoEncoder)
|
with vae_info as vae:
|
||||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
assert isinstance(vae, AutoencoderKL)
|
||||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
|
||||||
img = vae.decode(latents)
|
|
||||||
|
|
||||||
img.clamp(-1, 1)
|
img.clamp(-1, 1)
|
||||||
img = rearrange(img[0], "c h w -> h w c")
|
img = rearrange(img[0], "c h w -> h w c")
|
||||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||||
|
|
||||||
return img_pil
|
latents = flux_pipeline_with_vae._unpack_latents(
|
||||||
|
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
|
||||||
|
)
|
||||||
|
latents = (
|
||||||
|
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
||||||
|
) + flux_pipeline_with_vae.vae.config.shift_factor
|
||||||
|
latents = latents.to(dtype=vae.dtype)
|
||||||
|
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
||||||
|
|
||||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
assert isinstance(image, Image.Image)
|
||||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
return image
|
||||||
if self.quantization_type == "raw":
|
|
||||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
|
||||||
params = flux_configs["flux-schnell"].params
|
|
||||||
|
|
||||||
# Initialize the model on the "meta" device.
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
model = Flux(params).to(inference_dtype)
|
|
||||||
|
|
||||||
state_dict = load_file(path)
|
|
||||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
|
||||||
elif self.quantization_type == "NF4":
|
|
||||||
model_path = path.parent / "bnb_nf4.safetensors"
|
|
||||||
|
|
||||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
|
||||||
params = flux_configs["flux-schnell"].params
|
|
||||||
# Initialize the model on the "meta" device.
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
model = Flux(params)
|
|
||||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
|
||||||
# this on GPUs without bfloat16 support.
|
|
||||||
state_dict = load_file(model_path)
|
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
|
||||||
|
|
||||||
elif self.quantization_type == "llm_int8":
|
|
||||||
raise NotImplementedError("LLM int8 quantization is not yet supported.")
|
|
||||||
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
|
||||||
# with accelerate.init_empty_weights():
|
|
||||||
# empty_model = FluxTransformer2DModel.from_config(model_config)
|
|
||||||
# assert isinstance(empty_model, FluxTransformer2DModel)
|
|
||||||
# model_int8_path = path / "bnb_llm_int8"
|
|
||||||
# assert model_int8_path.exists()
|
|
||||||
# with accelerate.init_empty_weights():
|
|
||||||
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
|
||||||
|
|
||||||
# sd = load_file(model_int8_path / "model.safetensors")
|
|
||||||
# model.load_state_dict(sd, strict=True, assign=True)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
|
||||||
|
|
||||||
assert isinstance(model, FluxTransformer2DModel)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_flux_vae(path: Path) -> AutoEncoder:
|
|
||||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
|
||||||
ae_params = flux_configs["flux1-schnell"].ae_params
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
ae = AutoEncoder(ae_params)
|
|
||||||
|
|
||||||
state_dict = load_file(path)
|
|
||||||
ae.load_state_dict(state_dict, strict=True, assign=True)
|
|
||||||
return ae
|
|
||||||
|
@ -65,6 +65,10 @@ class TransformerField(BaseModel):
|
|||||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||||
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
||||||
|
|
||||||
|
class T5EncoderField(BaseModel):
|
||||||
|
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||||
|
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||||
|
|
||||||
|
|
||||||
class VAEField(BaseModel):
|
class VAEField(BaseModel):
|
||||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||||
@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
|||||||
"""Flux base model loader output"""
|
"""Flux base model loader output"""
|
||||||
|
|
||||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
|
||||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@ -166,7 +170,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
return FluxModelLoaderOutput(
|
return FluxModelLoaderOutput(
|
||||||
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
||||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
|
||||||
vae=VAEField(vae=vae),
|
vae=VAEField(vae=vae),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
|
|
||||||
# TO DO: Add exception handling
|
# TO DO: Add exception handling
|
||||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||||
if module in ["diffusers", "transformers"]:
|
if module in [
|
||||||
|
"diffusers",
|
||||||
|
"transformers",
|
||||||
|
"invokeai.backend.quantization.fast_quantized_transformers_model",
|
||||||
|
"invokeai.backend.quantization.fast_quantized_diffusion_model",
|
||||||
|
]:
|
||||||
res_type = sys.modules[module]
|
res_type = sys.modules[module]
|
||||||
else:
|
else:
|
||||||
res_type = sys.modules["diffusers"].pipelines
|
res_type = sys.modules["diffusers"].pipelines
|
||||||
|
@ -9,7 +9,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||||
@ -50,6 +50,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
|
elif isinstance(
|
||||||
|
model,
|
||||||
|
(
|
||||||
|
T5TokenizerFast,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return len(model)
|
||||||
else:
|
else:
|
||||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||||
# supported model types.
|
# supported model types.
|
||||||
|
@ -12,15 +12,17 @@ from diffusers.utils import (
|
|||||||
)
|
)
|
||||||
from optimum.quanto.models import QuantizedDiffusersModel
|
from optimum.quanto.models import QuantizedDiffusersModel
|
||||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||||
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
|
|
||||||
from invokeai.backend.requantize import requantize
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
|
||||||
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
|
||||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||||
if cls.base_class is None:
|
base_class = base_class or cls.base_class
|
||||||
|
if base_class is None:
|
||||||
raise ValueError("The `base_class` attribute needs to be configured.")
|
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||||
|
|
||||||
if not is_accelerate_available():
|
if not is_accelerate_available():
|
||||||
@ -43,16 +45,16 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
|||||||
|
|
||||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||||
original_model_cls_name = json.load(f)["_class_name"]
|
original_model_cls_name = json.load(f)["_class_name"]
|
||||||
configured_cls_name = cls.base_class.__name__
|
configured_cls_name = base_class.__name__
|
||||||
if configured_cls_name != original_model_cls_name:
|
if configured_cls_name != original_model_cls_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create an empty model
|
# Create an empty model
|
||||||
config = cls.base_class.load_config(model_name_or_path)
|
config = base_class.load_config(model_name_or_path)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = cls.base_class.from_config(config)
|
model = base_class.from_config(config)
|
||||||
|
|
||||||
# Look for the index of a sharded checkpoint
|
# Look for the index of a sharded checkpoint
|
||||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
@ -72,6 +74,6 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
|||||||
# Requantize and load quantized weights from state_dict
|
# Requantize and load quantized weights from state_dict
|
||||||
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||||
model.eval()
|
model.eval()
|
||||||
return cls(model)
|
return cls(model)._wrapped
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from optimum.quanto.models import QuantizedTransformersModel
|
from optimum.quanto.models import QuantizedTransformersModel
|
||||||
@ -7,15 +8,17 @@ from optimum.quanto.models.shared_dict import ShardedStateDict
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
||||||
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
|
|
||||||
from invokeai.backend.requantize import requantize
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
|
||||||
class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = AutoModelForTextEncoding, **kwargs):
|
||||||
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||||
if cls.auto_class is None:
|
auto_class = auto_class or cls.auto_class
|
||||||
|
if auto_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
|
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
|
||||||
)
|
)
|
||||||
@ -33,7 +36,7 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
|||||||
# Create an empty model
|
# Create an empty model
|
||||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = cls.auto_class.from_config(config)
|
model = auto_class.from_config(config)
|
||||||
# Look for the index of a sharded checkpoint
|
# Look for the index of a sharded checkpoint
|
||||||
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
if os.path.exists(checkpoint_file):
|
if os.path.exists(checkpoint_file):
|
||||||
@ -56,6 +59,6 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
|||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
# Set model in evaluation mode as it is done in transformers
|
# Set model in evaluation mode as it is done in transformers
|
||||||
model.eval()
|
model.eval()
|
||||||
return cls(model)
|
return cls(model)._wrapped
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
||||||
|
@ -5697,15 +5697,15 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
transformer: components["schemas"]["TransformerField"];
|
transformer: components["schemas"]["TransformerField"];
|
||||||
/**
|
/**
|
||||||
* CLIP 1
|
* CLIP
|
||||||
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
||||||
*/
|
*/
|
||||||
clip: components["schemas"]["CLIPField"];
|
clip: components["schemas"]["CLIPField"];
|
||||||
/**
|
/**
|
||||||
* CLIP 2
|
* T5 Encoder
|
||||||
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
* @description T5 tokenizer and text encoder
|
||||||
*/
|
*/
|
||||||
clip2: components["schemas"]["CLIPField"];
|
t5Encoder: components["schemas"]["T5EncoderField"];
|
||||||
/**
|
/**
|
||||||
* VAE
|
* VAE
|
||||||
* @description VAE
|
* @description VAE
|
||||||
@ -5739,19 +5739,17 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
use_cache?: boolean;
|
use_cache?: boolean;
|
||||||
/**
|
/**
|
||||||
* Model
|
* CLIP
|
||||||
* @description The FLUX model to use for text-to-image generation.
|
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
||||||
* @default null
|
* @default null
|
||||||
* @constant
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
*/
|
||||||
model?: "flux-schnell";
|
clip?: components["schemas"]["CLIPField"];
|
||||||
/**
|
/**
|
||||||
* Use 8Bit
|
* T5EncoderField
|
||||||
* @description Whether to quantize the transformer model to 8-bit precision.
|
* @description T5 tokenizer and text encoder
|
||||||
* @default false
|
* @default null
|
||||||
*/
|
*/
|
||||||
use_8bit?: boolean;
|
t5Encoder?: components["schemas"]["T5EncoderField"];
|
||||||
/**
|
/**
|
||||||
* Positive Prompt
|
* Positive Prompt
|
||||||
* @description Positive prompt for text-to-image generation.
|
* @description Positive prompt for text-to-image generation.
|
||||||
@ -5799,31 +5797,16 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
use_cache?: boolean;
|
use_cache?: boolean;
|
||||||
/**
|
/**
|
||||||
* @description The Flux model
|
* Transformer
|
||||||
|
* @description UNet (scheduler, LoRAs)
|
||||||
* @default null
|
* @default null
|
||||||
*/
|
*/
|
||||||
flux_model?: components["schemas"]["ModelIdentifierField"];
|
transformer?: components["schemas"]["TransformerField"];
|
||||||
/**
|
/**
|
||||||
* Model
|
* @description VAE
|
||||||
* @description The FLUX model to use for text-to-image generation.
|
|
||||||
* @default null
|
* @default null
|
||||||
* @constant
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
*/
|
||||||
model?: "flux-schnell";
|
vae?: components["schemas"]["VAEField"];
|
||||||
/**
|
|
||||||
* Quantization Type
|
|
||||||
* @description The type of quantization to use for the transformer model.
|
|
||||||
* @default raw
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
quantization_type?: "raw" | "NF4" | "llm_int8";
|
|
||||||
/**
|
|
||||||
* Use 8Bit
|
|
||||||
* @description Whether to quantize the transformer model to 8-bit precision.
|
|
||||||
* @default false
|
|
||||||
*/
|
|
||||||
use_8bit?: boolean;
|
|
||||||
/**
|
/**
|
||||||
* @description Positive conditioning tensor
|
* @description Positive conditioning tensor
|
||||||
* @default null
|
* @default null
|
||||||
@ -14268,6 +14251,13 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
type: "t2i_adapter_output";
|
type: "t2i_adapter_output";
|
||||||
};
|
};
|
||||||
|
/** T5EncoderField */
|
||||||
|
T5EncoderField: {
|
||||||
|
/** @description Info to load tokenizer submodel */
|
||||||
|
tokenizer: components["schemas"]["ModelIdentifierField"];
|
||||||
|
/** @description Info to load text_encoder submodel */
|
||||||
|
text_encoder: components["schemas"]["ModelIdentifierField"];
|
||||||
|
};
|
||||||
/** TBLR */
|
/** TBLR */
|
||||||
TBLR: {
|
TBLR: {
|
||||||
/** Top */
|
/** Top */
|
||||||
|
Loading…
Reference in New Issue
Block a user