Run ruff, setup initial text to image node

This commit is contained in:
Brandon Rising 2024-08-19 10:14:58 -04:00 committed by Brandon
parent 436f18ff55
commit 1bd90e0fd4
15 changed files with 291 additions and 124 deletions

View File

@ -1,8 +1,4 @@
import torch
from einops import repeat
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -10,9 +6,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(

View File

@ -1,12 +1,6 @@
from typing import Literal
import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
@ -20,23 +14,12 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
base_class = FluxTransformer2DModel
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
auto_class = AutoModelForTextEncoding
@invocation(
"flux_text_to_image",
@ -78,7 +61,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)
image = self._run_vae_decoding(context, flux_ae_path, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@ -89,14 +72,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = TorchDevice.choose_torch_dtype()
# Prepare input noise.
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
# CPU RNG?
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, FluxTransformer2DModel)
assert isinstance(transformer, Flux)
x = denoise(
model=transformer,
@ -144,21 +153,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
assert isinstance(vae, AutoEncoder)
# TODO(ryand): Test that this works with both float16 and bfloat16.
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
img = vae.decode(latents)
img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
latents = flux_pipeline_with_vae._unpack_latents(
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
)
latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
assert isinstance(image, Image.Image)
return image
return img_pil

View File

@ -1,6 +1,6 @@
import copy
from time import sleep
from typing import List, Optional, Literal, Dict
from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, Field
@ -12,10 +12,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
class ModelIdentifierField(BaseModel):
@ -132,31 +132,22 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": {
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
"name": "t5_base_encoder",
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"text_encoder_repo": "hf_repo1",
"tokenizer_repo": "hf_repo1",
"text_encoder_name": "hf_repo1",
"tokenizer_name": "hf_repo1",
"format": ModelFormat.T5Encoder8b,
},
"4b_quantized": {
"text_encoder_repo": "hf_repo2",
"tokenizer_repo": "hf_repo2",
"text_encoder_name": "hf_repo2",
"tokenizer_name": "hf_repo2",
"format": ModelFormat.T5Encoder8b,
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
"name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder,
},
}
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
@ -189,7 +180,15 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
vae = self._install_model(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
ModelFormat.Checkpoint,
ModelType.VAE,
BaseModelType.Flux,
)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
@ -198,33 +197,59 @@ class FluxModelLoaderInvocation(BaseInvocation):
vae=VAEField(vae=vae),
)
def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
match(submodel):
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
case SubModelType.TextEncoder2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
case SubModelType.Tokenizer2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
return self._install_model(
context,
submodel,
"clip-vit-large-patch14",
"openai/clip-vit-large-patch14",
ModelFormat.Diffusers,
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._install_model(
context,
submodel,
T5_ENCODER_MAP[self.t5_encoder]["name"],
T5_ENCODER_MAP[self.t5_encoder]["repo"],
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
def _install_model(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
repo_id: str,
format: ModelFormat,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
model_path = context.models.download_and_cache_model(repo_id)
config = ModelRecordChanges(name = name, base = base, type=type, format=format)
config = ModelRecordChanges(name=name, base=base, type=type, format=format)
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
while not model_install_job.in_terminal_state:
sleep(0.01)
if not model_install_job.config_out:
raise Exception(f"Failed to install {name}")
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
update={"submodel_type": submodel}
)
@invocation(
"main_model_loader",

View File

@ -301,7 +301,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError as e:
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.

View File

@ -465,18 +465,20 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.install.download_and_cache_model(source=source)
def import_local_model(
self,
model_path: Path,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
self,
model_path: Path,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
):
"""
TODO: Fill out description of this method
"""
if not model_path.exists():
raise Exception("Models provided to import_local_model must already exist on disk")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
return self._services.model_manager.install.heuristic_import(
str(model_path), config=config, access_token=access_token, inplace=inplace
)
def load_local_model(
self,

View File

@ -3,9 +3,15 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
@ -35,9 +41,7 @@ class Flux(nn.Module):
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")

View File

@ -1,5 +1,6 @@
from torch import Tensor, nn
from transformers import (PreTrainedModel, PreTrainedTokenizer)
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):

View File

@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

View File

@ -0,0 +1,134 @@
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux
from .modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)

View File

@ -1,14 +1,17 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
import yaml
from dataclasses import fields
from safetensors.torch import load_file
from typing import Optional, Any
from transformers import T5EncoderModel, T5Tokenizer
from pathlib import Path
from typing import Any, Optional
import yaml
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -19,20 +22,15 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
MainCheckpointConfig,
CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
from invokeai.backend.util.silence_warnings import SilenceWarnings
app_config = get_config()
@ -58,7 +56,7 @@ class FluxVAELoader(GenericDiffusersLoader):
raise
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields}
filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)
with SilenceWarnings():
@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
raise Exception("Only Checkpoint Flux models are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(GenericDiffusersLoader):
"""Class to load main models."""
@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path))
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
raise Exception("Only Checkpoint Flux models are currently supported.")
@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)
with SilenceWarnings():

View File

@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
return model.calc_size()
elif isinstance(
model,
(T5TokenizerFast,T5Tokenizer,),
(
T5TokenizerFast,
T5Tokenizer,
),
):
return len(model)
else:

View File

@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
}
},
}
@ -323,7 +323,7 @@ class ModelProbe(object):
if model_type is ModelType.Main:
if base_type == BaseModelType.Flux:
config_file="flux/flux1-schnell.yaml"
config_file = "flux/flux1-schnell.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors