Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove dependency on flux config files

parent 56b9906e2e
commit 70c278c810
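This commit stops the FLUX loaders from parsing legacy YAML config files at runtime: the parameters move into Python constants in invokeai/backend/flux/util.py, and config_path now stores a registry key ("flux-dev" or "flux-schnell") instead of a YAML path. A minimal sketch of the new lookup pattern, using names from the diff below (lookup_max_seq_len is an illustrative helper, not part of the commit):

from typing import Dict, Literal

# Keyed by the new config_path values; the numbers mirror max_seq_len
# from the deleted YAML files at the bottom of this diff.
max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-schnell": 256,
}

def lookup_max_seq_len(config_path: str) -> int:
    # Replaces: flux_conf = yaml.safe_load(stream); flux_conf["max_seq_len"]
    return max_seq_lengths[config_path]

assert lookup_max_seq_len("flux-dev") == 512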
invokeai/app/invocations/model.py

@@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
+from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
@@ -188,17 +189,13 @@ class FluxModelLoaderInvocation(BaseInvocation):
         vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
-        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
-            max_seq_len=flux_conf["max_seq_len"],
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
         )

     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
invokeai/backend/flux/util.py

@@ -2,6 +2,7 @@

 import os
 from dataclasses import dataclass
+from typing import Dict, Literal

 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@@ -18,6 +19,25 @@ class ModelSpec:
     repo_ae: str | None


+max_seq_lengths: Dict[str, Literal[256, 512]] = {
+    "flux-dev": 512,
+    "flux-schnell": 256,
+}
+
+
+ae_params = AutoEncoderParams(
+    resolution=256,
+    in_channels=3,
+    ch=128,
+    out_ch=3,
+    ch_mult=[1, 2, 4, 4],
+    num_res_blocks=2,
+    z_channels=16,
+    scale_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+
 configs = {
     "flux-dev": ModelSpec(
         repo_id="black-forest-labs/FLUX.1-dev",
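For orientation, a short sketch of how the loaders in the next file consume these new constants. The two helper functions are illustrative, not part of the commit; the imports are the ones the diff itself introduces:

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, configs

def build_flux_transformer(config_path: str) -> Flux:
    # config_path is now a registry key ("flux-dev" / "flux-schnell"),
    # not a relative path to a legacy YAML file.
    return Flux(configs[config_path].params)

def build_flux_vae() -> AutoEncoder:
    # A single shared parameter set replaces the deleted flux1-vae.yaml.
    return AutoEncoder(ae_params)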
invokeai/backend/model_manager/load/model_loaders/flux.py

@@ -1,19 +1,18 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""

-from dataclasses import fields
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional

 import accelerate
 import torch
-import yaml
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.flux.model import Flux, FluxParams
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs, ae_params
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -58,17 +57,9 @@ class FluxVAELoader(ModelLoader):
         if not isinstance(config, VAECheckpointConfig):
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
-
-        dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = AutoEncoderParams(**filtered_data)

         with SilenceWarnings():
-            model = AutoEncoder(params)
+            model = AutoEncoder(ae_params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
@@ -182,14 +173,10 @@ class FluxCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -198,16 +185,12 @@ class FluxCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
-            model = Flux(params)
+            model = Flux(configs[config.config_path].params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             return model
@@ -224,14 +207,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -240,7 +219,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
         if not bnb_available:
@@ -248,13 +226,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(params)
+                model = Flux(configs[config.config_path].params)
                 model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
         sd = load_file(model_path)
         model.load_state_dict(sd, assign=True)
invokeai/backend/model_manager/probe.py

@@ -329,9 +329,9 @@ class ModelProbe(object):
             checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
             state_dict = checkpoint.get("state_dict") or checkpoint
             if "guidance_in.out_layer.weight" in state_dict:
-                config_file = "flux/flux1-dev.yaml"
+                config_file = "flux-dev"
             else:
-                config_file = "flux/flux1-schnell.yaml"
+                config_file = "flux-schnell"
         else:
             config_file = LEGACY_CONFIGS[base_type][variant_type]
             if isinstance(config_file, dict):  # need another tier for sd-2.x models
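The probe heuristic above works because FLUX dev carries guidance-embedding weights that schnell lacks (compare guidance_embed: True/False in the deleted YAML files below), so a single state-dict key distinguishes the two variants. A compact restatement as an illustrative helper, assuming an already-loaded state dict:

def flux_variant_key(state_dict: dict) -> str:
    # The returned key matches the entries in max_seq_lengths and configs,
    # and is what the probe now records in config_path.
    if "guidance_in.out_layer.weight" in state_dict:
        return "flux-dev"
    return "flux-schnell"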
invokeai/configs/flux/flux1-dev.yaml (deleted)

@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-dev"
-repo_ae: "ae.safetensors"
-max_seq_len: 512
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: True
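For reference, the params block of the deleted file above maps one-to-one onto a FluxParams dataclass in invokeai/backend/flux/util.py. A hedged sketch of that mapping (flux_dev_params is an illustrative name; in the commit the instance is wired into the "flux-dev" ModelSpec entry of configs):

from invokeai.backend.flux.model import FluxParams

flux_dev_params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,  # False for flux-schnell (see next file)
)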
invokeai/configs/flux/flux1-schnell.yaml (deleted)

@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_ae: "ae.safetensors"
-max_seq_len: 256
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: False
invokeai/configs/flux/flux1-vae.yaml (deleted)

@@ -1,16 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_path: "ae.safetensors"
-params:
-  resolution: 256
-  in_channels: 3
-  ch: 128
-  out_ch: 3
-  ch_mult:
-    - 1
-    - 2
-    - 4
-    - 4
-  num_res_blocks: 2
-  z_channels: 16
-  scale_factor: 0.3611
-  shift_factor: 0.1159