Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Run Ruff
This commit is contained in:
parent 46d5107ff1
commit 46b6314482
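The commit message only says that Ruff was run; the exact command and project configuration are not recorded in the diff. Assuming import sorting (rule group "I") and unused-import removal (rule F401) are enabled in the repository's Ruff config, a lint autofix pass followed by the formatter would produce the kind of import reshuffling shown in the hunks below. A minimal sketch of such an invocation (hypothetical, run from the repository root):

import subprocess

# Apply auto-fixable lint rules, including import sorting ("I") and
# unused-import removal (F401), then run the Ruff formatter.
subprocess.run(["ruff", "check", "--fix", "."], check=True)
subprocess.run(["ruff", "format", "."], check=True)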
@@ -1,14 +1,9 @@
from pathlib import Path

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -40,7 +35,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
    # compatible with other ConditioningOutputs.
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:

        t5_embeddings, clip_embeddings = self._encode_prompt(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
@@ -48,7 +42,7 @@ class FluxTextEncoderInvocation(BaseInvocation):

        conditioning_name = context.conditioning.save(conditioning_data)
        return ConditioningOutput.build(conditioning_name)


    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
        # TODO: Determine the T5 max sequence length based on the model.
        # if self.model == "flux-schnell":
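The invoke() methods in this commit are decorated with @torch.no_grad(). As background (standard PyTorch behaviour, not anything specific to this change), the decorator disables autograd tracking for everything executed inside the call, which avoids building a gradient graph during pure inference. A minimal, self-contained sketch:

import torch

@torch.no_grad()
def encode(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside this call.
    return x * 2.0

out = encode(torch.ones(2, requires_grad=True))
print(out.requires_grad)  # False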
@@ -1,13 +1,9 @@
from pathlib import Path
from typing import Literal
from pydantic import Field

import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding
@@ -20,8 +16,8 @@ from invokeai.app.invocations.fields import (
    InputField,
    WithBoard,
    WithMetadata,
    UIType,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -75,7 +71,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
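The text-to-image module imports load_file from safetensors.torch. For context, this is the standard safetensors API for reading a checkpoint back into a plain dict of tensors; the file name below is a placeholder, not a path from the repository:

import torch
from safetensors.torch import load_file, save_file

# Round-trip a tiny state dict through the safetensors format.
state_dict = {"weight": torch.zeros(4, 4)}
save_file(state_dict, "example.safetensors")
loaded = load_file("example.safetensors")  # dict[str, torch.Tensor]
print(loaded["weight"].shape)  # torch.Size([4, 4])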
@@ -3,6 +3,7 @@ import os
from typing import Union

from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
@@ -12,7 +13,6 @@ from diffusers.utils import (
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.requantize import requantize

@@ -1,14 +1,13 @@
import json
import os
import torch
from typing import Union

from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from transformers.models.auto import AutoModelForTextEncoding
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available

from invokeai.backend.requantize import requantize

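The quantization helpers above import qfloat8 and the QuantizedDiffusersModel / QuantizedTransformersModel wrappers from optimum.quanto, alongside InvokeAI's own requantize helper. As background only, and not a description of the project's wrapper logic, weight-only float8 quantization of a torch module with the generic optimum-quanto API looks roughly like this:

import torch
from optimum.quanto import freeze, qfloat8, quantize

# Stand-in module; in this commit the real targets are the FLUX transformer
# and the T5/CLIP text encoders.
model = torch.nn.Linear(16, 16)

# Swap weights for qfloat8 quantized versions, then freeze to materialize
# the quantized tensors.
quantize(model, weights=qfloat8)
freeze(model)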