This commit is contained in:
Brandon Rising 2024-08-15 10:27:42 -04:00 committed by Brandon
parent 9ed53af520
commit 2d9042fb93
4 changed files with 5 additions and 17 deletions

View File

@ -1,14 +1,9 @@
from pathlib import Path
import torch import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -40,7 +35,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
# compatible with other ConditioningOutputs. # compatible with other ConditioningOutputs.
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context) t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData( conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)] conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
@ -48,7 +42,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]: def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Determine the T5 max sequence length based on the model. # TODO: Determine the T5 max sequence length based on the model.
# if self.model == "flux-schnell": # if self.model == "flux-schnell":

View File

@ -1,13 +1,9 @@
from pathlib import Path
from typing import Literal from typing import Literal
from pydantic import Field
import accelerate import accelerate
import torch import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding from transformers.models.auto import AutoModelForTextEncoding
@ -20,8 +16,8 @@ from invokeai.app.invocations.fields import (
InputField, InputField,
WithBoard, WithBoard,
WithMetadata, WithMetadata,
UIType,
) )
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@ -75,7 +71,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data. # Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name) cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1 assert len(cond_data.conditionings) == 1

View File

@ -3,6 +3,7 @@ import os
from typing import Union from typing import Union
from diffusers.models.model_loading_utils import load_state_dict from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import ( from diffusers.utils import (
CONFIG_NAME, CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME,
@ -12,7 +13,6 @@ from diffusers.utils import (
) )
from optimum.quanto.models import QuantizedDiffusersModel from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from invokeai.backend.requantize import requantize from invokeai.backend.requantize import requantize

View File

@ -1,14 +1,13 @@
import json import json
import os import os
import torch
from typing import Union from typing import Union
from optimum.quanto.models import QuantizedTransformersModel from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from transformers.models.auto import AutoModelForTextEncoding from transformers.models.auto import AutoModelForTextEncoding
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from invokeai.backend.requantize import requantize from invokeai.backend.requantize import requantize