mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run ruff, setup initial text to image node
This commit is contained in:
parent
436f18ff55
commit
1bd90e0fd4
@ -1,8 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
@ -10,9 +6,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
|||||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -1,12 +1,6 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from einops import rearrange, repeat
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
@ -20,23 +14,12 @@ from invokeai.app.invocations.fields import (
|
|||||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.flux.model import Flux
|
||||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
TFluxModelKeys = Literal["flux-schnell"]
|
|
||||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
|
||||||
|
|
||||||
|
|
||||||
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
|
|
||||||
base_class = FluxTransformer2DModel
|
|
||||||
|
|
||||||
|
|
||||||
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
|
||||||
auto_class = AutoModelForTextEncoding
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"flux_text_to_image",
|
"flux_text_to_image",
|
||||||
@ -78,7 +61,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||||
|
|
||||||
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||||
image = self._run_vae_decoding(context, latents)
|
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||||
image_dto = context.images.save(image=image)
|
image_dto = context.images.save(image=image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
@ -89,14 +72,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
t5_embeddings: torch.Tensor,
|
t5_embeddings: torch.Tensor,
|
||||||
):
|
):
|
||||||
transformer_info = context.models.load(self.transformer.transformer)
|
transformer_info = context.models.load(self.transformer.transformer)
|
||||||
|
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
|
# Prepare input noise.
|
||||||
|
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||||
|
# CPU RNG?
|
||||||
|
x = get_noise(
|
||||||
|
num_samples=1,
|
||||||
|
height=self.height,
|
||||||
|
width=self.width,
|
||||||
|
device=TorchDevice.choose_torch_device(),
|
||||||
|
dtype=inference_dtype,
|
||||||
|
seed=self.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
img, img_ids = self._prepare_latent_img_patches(x)
|
||||||
|
|
||||||
|
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||||
|
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
|
||||||
|
timesteps = get_schedule(
|
||||||
|
num_steps=self.num_steps,
|
||||||
|
image_seq_len=img.shape[1],
|
||||||
|
shift=not is_schnell,
|
||||||
|
)
|
||||||
|
|
||||||
|
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||||
|
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||||
|
|
||||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||||
# if the cache is not empty.
|
# if the cache is not empty.
|
||||||
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||||
|
|
||||||
with transformer_info as transformer:
|
with transformer_info as transformer:
|
||||||
assert isinstance(transformer, FluxTransformer2DModel)
|
assert isinstance(transformer, Flux)
|
||||||
|
|
||||||
x = denoise(
|
x = denoise(
|
||||||
model=transformer,
|
model=transformer,
|
||||||
@ -144,21 +153,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
assert isinstance(vae, AutoencoderKL)
|
assert isinstance(vae, AutoEncoder)
|
||||||
|
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||||
|
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||||
|
img = vae.decode(latents)
|
||||||
|
|
||||||
img.clamp(-1, 1)
|
img.clamp(-1, 1)
|
||||||
img = rearrange(img[0], "c h w -> h w c")
|
img = rearrange(img[0], "c h w -> h w c")
|
||||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||||
|
|
||||||
latents = flux_pipeline_with_vae._unpack_latents(
|
return img_pil
|
||||||
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
|
|
||||||
)
|
|
||||||
latents = (
|
|
||||||
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
|
||||||
) + flux_pipeline_with_vae.vae.config.shift_factor
|
|
||||||
latents = latents.to(dtype=vae.dtype)
|
|
||||||
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
|
||||||
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
|
||||||
|
|
||||||
assert isinstance(image, Image.Image)
|
|
||||||
return image
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import List, Optional, Literal, Dict
|
from typing import Dict, List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -12,10 +12,10 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||||
|
from invokeai.app.services.model_records import ModelRecordChanges
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordChanges
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
|
|
||||||
|
|
||||||
|
|
||||||
class ModelIdentifierField(BaseModel):
|
class ModelIdentifierField(BaseModel):
|
||||||
@ -132,31 +132,22 @@ class ModelIdentifierInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return ModelIdentifierOutput(model=self.model)
|
return ModelIdentifierOutput(model=self.model)
|
||||||
|
|
||||||
T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
|
|
||||||
|
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
||||||
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
||||||
"base": {
|
"base": {
|
||||||
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
|
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
|
||||||
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
|
"name": "t5_base_encoder",
|
||||||
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
|
|
||||||
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
|
|
||||||
"format": ModelFormat.T5Encoder,
|
"format": ModelFormat.T5Encoder,
|
||||||
},
|
},
|
||||||
"8b_quantized": {
|
"8b_quantized": {
|
||||||
"text_encoder_repo": "hf_repo1",
|
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
|
||||||
"tokenizer_repo": "hf_repo1",
|
"name": "t5_8b_quantized_encoder",
|
||||||
"text_encoder_name": "hf_repo1",
|
"format": ModelFormat.T5Encoder,
|
||||||
"tokenizer_name": "hf_repo1",
|
|
||||||
"format": ModelFormat.T5Encoder8b,
|
|
||||||
},
|
|
||||||
"4b_quantized": {
|
|
||||||
"text_encoder_repo": "hf_repo2",
|
|
||||||
"tokenizer_repo": "hf_repo2",
|
|
||||||
"text_encoder_name": "hf_repo2",
|
|
||||||
"tokenizer_name": "hf_repo2",
|
|
||||||
"format": ModelFormat.T5Encoder8b,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("flux_model_loader_output")
|
@invocation_output("flux_model_loader_output")
|
||||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Flux base model loader output"""
|
"""Flux base model loader output"""
|
||||||
@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
ui_type=UIType.FluxMainModel,
|
ui_type=UIType.FluxMainModel,
|
||||||
input=Input.Direct,
|
input=Input.Direct,
|
||||||
)
|
)
|
||||||
|
|
||||||
t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
|
t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||||
@ -189,7 +180,15 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||||
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
|
vae = self._install_model(
|
||||||
|
context,
|
||||||
|
SubModelType.VAE,
|
||||||
|
"FLUX.1-schnell_ae",
|
||||||
|
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
||||||
|
ModelFormat.Checkpoint,
|
||||||
|
ModelType.VAE,
|
||||||
|
BaseModelType.Flux,
|
||||||
|
)
|
||||||
|
|
||||||
return FluxModelLoaderOutput(
|
return FluxModelLoaderOutput(
|
||||||
transformer=TransformerField(transformer=transformer),
|
transformer=TransformerField(transformer=transformer),
|
||||||
@ -198,33 +197,59 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
vae=VAEField(vae=vae),
|
vae=VAEField(vae=vae),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
|
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||||
match(submodel):
|
match submodel:
|
||||||
case SubModelType.Transformer:
|
case SubModelType.Transformer:
|
||||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||||
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
|
return self._install_model(
|
||||||
case SubModelType.TextEncoder2:
|
context,
|
||||||
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
|
submodel,
|
||||||
case SubModelType.Tokenizer2:
|
"clip-vit-large-patch14",
|
||||||
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
|
"openai/clip-vit-large-patch14",
|
||||||
|
ModelFormat.Diffusers,
|
||||||
|
ModelType.CLIPEmbed,
|
||||||
|
BaseModelType.Any,
|
||||||
|
)
|
||||||
|
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||||
|
return self._install_model(
|
||||||
|
context,
|
||||||
|
submodel,
|
||||||
|
T5_ENCODER_MAP[self.t5_encoder]["name"],
|
||||||
|
T5_ENCODER_MAP[self.t5_encoder]["repo"],
|
||||||
|
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
|
||||||
|
ModelType.T5Encoder,
|
||||||
|
BaseModelType.Any,
|
||||||
|
)
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||||
|
|
||||||
def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
|
def _install_model(
|
||||||
if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
submodel: SubModelType,
|
||||||
|
name: str,
|
||||||
|
repo_id: str,
|
||||||
|
format: ModelFormat,
|
||||||
|
type: ModelType,
|
||||||
|
base: BaseModelType,
|
||||||
|
):
|
||||||
|
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||||
if len(models) != 1:
|
if len(models) != 1:
|
||||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||||
else:
|
else:
|
||||||
model_path = context.models.download_and_cache_model(repo_id)
|
model_path = context.models.download_and_cache_model(repo_id)
|
||||||
config = ModelRecordChanges(name = name, base = base, type=type, format=format)
|
config = ModelRecordChanges(name=name, base=base, type=type, format=format)
|
||||||
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
|
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
|
||||||
while not model_install_job.in_terminal_state:
|
while not model_install_job.in_terminal_state:
|
||||||
sleep(0.01)
|
sleep(0.01)
|
||||||
if not model_install_job.config_out:
|
if not model_install_job.config_out:
|
||||||
raise Exception(f"Failed to install {name}")
|
raise Exception(f"Failed to install {name}")
|
||||||
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
|
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
|
||||||
|
update={"submodel_type": submodel}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"main_model_loader",
|
"main_model_loader",
|
||||||
|
@ -301,7 +301,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
for row in result:
|
for row in result:
|
||||||
try:
|
try:
|
||||||
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
|
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
|
||||||
except pydantic.ValidationError as e:
|
except pydantic.ValidationError:
|
||||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||||
# newer version of the app that added a new model type.
|
# newer version of the app that added a new model type.
|
||||||
|
@ -465,18 +465,20 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||||
|
|
||||||
def import_local_model(
|
def import_local_model(
|
||||||
self,
|
self,
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[ModelRecordChanges] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
inplace: Optional[bool] = False,
|
inplace: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
TODO: Fill out description of this method
|
TODO: Fill out description of this method
|
||||||
"""
|
"""
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise Exception("Models provided to import_local_model must already exist on disk")
|
raise Exception("Models provided to import_local_model must already exist on disk")
|
||||||
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
|
return self._services.model_manager.install.heuristic_import(
|
||||||
|
str(model_path), config=config, access_token=access_token, inplace=inplace
|
||||||
|
)
|
||||||
|
|
||||||
def load_local_model(
|
def load_local_model(
|
||||||
self,
|
self,
|
||||||
|
@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
|
|||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
|
@ -3,9 +3,15 @@ from dataclasses import dataclass
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
from invokeai.backend.flux.modules.layers import (
|
||||||
MLPEmbedder, SingleStreamBlock,
|
DoubleStreamBlock,
|
||||||
timestep_embedding)
|
EmbedND,
|
||||||
|
LastLayer,
|
||||||
|
MLPEmbedder,
|
||||||
|
SingleStreamBlock,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FluxParams:
|
class FluxParams:
|
||||||
@ -35,9 +41,7 @@ class Flux(nn.Module):
|
|||||||
self.in_channels = params.in_channels
|
self.in_channels = params.in_channels
|
||||||
self.out_channels = self.in_channels
|
self.out_channels = self.in_channels
|
||||||
if params.hidden_size % params.num_heads != 0:
|
if params.hidden_size % params.num_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
|
||||||
)
|
|
||||||
pe_dim = params.hidden_size // params.num_heads
|
pe_dim = params.hidden_size // params.num_heads
|
||||||
if sum(params.axes_dim) != pe_dim:
|
if sum(params.axes_dim) != pe_dim:
|
||||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||||
@ -108,4 +112,4 @@ class Flux(nn.Module):
|
|||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
@ -309,4 +309,4 @@ class AutoEncoder(nn.Module):
|
|||||||
return self.decoder(z)
|
return self.decoder(z)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return self.decode(self.encode(x))
|
return self.decode(self.encode(x))
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import (PreTrainedModel, PreTrainedTokenizer)
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
class HFEncoder(nn.Module):
|
class HFEncoder(nn.Module):
|
||||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||||
@ -27,4 +28,4 @@ class HFEncoder(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
)
|
)
|
||||||
return outputs[self.output_key]
|
return outputs[self.output_key]
|
||||||
|
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
|||||||
"""
|
"""
|
||||||
t = time_factor * t
|
t = time_factor * t
|
||||||
half = dim // 2
|
half = dim // 2
|
||||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||||
t.device
|
|
||||||
)
|
|
||||||
|
|
||||||
args = t[:, None].float() * freqs[None]
|
args = t[:, None].float() * freqs[None]
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
@ -250,4 +248,4 @@ class LastLayer(nn.Module):
|
|||||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||||
x = self.linear(x)
|
x = self.linear(x)
|
||||||
return x
|
return x
|
||||||
|
134
invokeai/backend/flux/sampling.py
Normal file
134
invokeai/backend/flux/sampling.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import math
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .model import Flux
|
||||||
|
from .modules.conditioner import HFEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def get_noise(
|
||||||
|
num_samples: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
):
|
||||||
|
return torch.randn(
|
||||||
|
num_samples,
|
||||||
|
16,
|
||||||
|
# allow for packing
|
||||||
|
2 * math.ceil(height / 16),
|
||||||
|
2 * math.ceil(width / 16),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
generator=torch.Generator(device=device).manual_seed(seed),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||||
|
bs, c, h, w = img.shape
|
||||||
|
if bs == 1 and not isinstance(prompt, str):
|
||||||
|
bs = len(prompt)
|
||||||
|
|
||||||
|
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
if img.shape[0] == 1 and bs > 1:
|
||||||
|
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||||
|
|
||||||
|
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
|
txt = t5(prompt)
|
||||||
|
if txt.shape[0] == 1 and bs > 1:
|
||||||
|
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||||
|
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||||
|
|
||||||
|
vec = clip(prompt)
|
||||||
|
if vec.shape[0] == 1 and bs > 1:
|
||||||
|
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"img": img,
|
||||||
|
"img_ids": img_ids.to(img.device),
|
||||||
|
"txt": txt.to(img.device),
|
||||||
|
"txt_ids": txt_ids.to(img.device),
|
||||||
|
"vec": vec.to(img.device),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||||
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||||
|
m = (y2 - y1) / (x2 - x1)
|
||||||
|
b = y1 - m * x1
|
||||||
|
return lambda x: m * x + b
|
||||||
|
|
||||||
|
|
||||||
|
def get_schedule(
|
||||||
|
num_steps: int,
|
||||||
|
image_seq_len: int,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
shift: bool = True,
|
||||||
|
) -> list[float]:
|
||||||
|
# extra step for zero
|
||||||
|
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||||
|
|
||||||
|
# shifting the schedule to favor high timesteps for higher signal images
|
||||||
|
if shift:
|
||||||
|
# eastimate mu based on linear estimation between two points
|
||||||
|
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||||
|
timesteps = time_shift(mu, 1.0, timesteps)
|
||||||
|
|
||||||
|
return timesteps.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(
|
||||||
|
model: Flux,
|
||||||
|
# model input
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
vec: Tensor,
|
||||||
|
# sampling parameters
|
||||||
|
timesteps: list[float],
|
||||||
|
guidance: float = 4.0,
|
||||||
|
):
|
||||||
|
# this is ignored for schnell
|
||||||
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
|
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
|
||||||
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
|
pred = model(
|
||||||
|
img=img,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=txt,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
y=vec,
|
||||||
|
timesteps=t_vec,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
)
|
||||||
|
|
||||||
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||||
|
h=math.ceil(height / 16),
|
||||||
|
w=math.ceil(width / 16),
|
||||||
|
ph=2,
|
||||||
|
pw=2,
|
||||||
|
)
|
@ -1,14 +1,17 @@
|
|||||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||||
"""Class for Flux model loading in InvokeAI."""
|
"""Class for Flux model loading in InvokeAI."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from safetensors.torch import load_file
|
from pathlib import Path
|
||||||
from typing import Optional, Any
|
from typing import Any, Optional
|
||||||
from transformers import T5EncoderModel, T5Tokenizer
|
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
from invokeai.backend.flux.model import Flux, FluxParams
|
||||||
|
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -19,20 +22,15 @@ from invokeai.backend.model_manager import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
CheckpointConfigBase,
|
CheckpointConfigBase,
|
||||||
MainCheckpointConfig,
|
|
||||||
CLIPEmbedDiffusersConfig,
|
CLIPEmbedDiffusersConfig,
|
||||||
|
MainCheckpointConfig,
|
||||||
T5EncoderConfig,
|
T5EncoderConfig,
|
||||||
VAECheckpointConfig,
|
VAECheckpointConfig,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config.config_default import get_config
|
|
||||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.flux.model import Flux, FluxParams
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
|
|
||||||
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
|
||||||
T5Tokenizer)
|
|
||||||
|
|
||||||
app_config = get_config()
|
app_config = get_config()
|
||||||
|
|
||||||
@ -56,9 +54,9 @@ class FluxVAELoader(GenericDiffusersLoader):
|
|||||||
flux_conf = yaml.safe_load(stream)
|
flux_conf = yaml.safe_load(stream)
|
||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
|
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
|
||||||
filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields}
|
filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
|
||||||
params = AutoEncoderParams(**filtered_data)
|
params = AutoEncoderParams(**filtered_data)
|
||||||
|
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||||
class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
match submodel_type:
|
match submodel_type:
|
||||||
case SubModelType.Tokenizer2:
|
case SubModelType.Tokenizer2:
|
||||||
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
|
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
|
||||||
case SubModelType.TextEncoder2:
|
case SubModelType.TextEncoder2:
|
||||||
return T5EncoderModel.from_pretrained(Path(config.path))
|
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
|
||||||
|
|
||||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||||
|
|
||||||
@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
|||||||
params = None
|
params = None
|
||||||
model_path = Path(config.path)
|
model_path = Path(config.path)
|
||||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||||
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
|
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||||
params = FluxParams(**filtered_data)
|
params = FluxParams(**filtered_data)
|
||||||
|
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
|
@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
|||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(
|
||||||
|
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||||
|
)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(
|
||||||
|
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
|
||||||
|
)
|
||||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
|
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
|
||||||
|
|
||||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||||
@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
model,
|
model,
|
||||||
(T5TokenizerFast,T5Tokenizer,),
|
(
|
||||||
|
T5TokenizerFast,
|
||||||
|
T5Tokenizer,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
return len(model)
|
return len(model)
|
||||||
else:
|
else:
|
||||||
|
@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
|||||||
},
|
},
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ class ModelProbe(object):
|
|||||||
fields = {}
|
fields = {}
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||||
@ -323,7 +323,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
if model_type is ModelType.Main:
|
if model_type is ModelType.Main:
|
||||||
if base_type == BaseModelType.Flux:
|
if base_type == BaseModelType.Flux:
|
||||||
config_file="flux/flux1-schnell.yaml"
|
config_file = "flux/flux1-schnell.yaml"
|
||||||
else:
|
else:
|
||||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||||
@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
|
|||||||
def get_format(self) -> ModelFormat:
|
def get_format(self) -> ModelFormat:
|
||||||
return ModelFormat.T5Encoder
|
return ModelFormat.T5Encoder
|
||||||
|
|
||||||
|
|
||||||
class ONNXFolderProbe(PipelineFolderProbe):
|
class ONNXFolderProbe(PipelineFolderProbe):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
# Due to the way the installer is set up, the configuration file for safetensors
|
# Due to the way the installer is set up, the configuration file for safetensors
|
||||||
|
Loading…
Reference in New Issue
Block a user