Remove dependency on flux config files

Brandon Rising 2024-08-25 02:41:13 -04:00 committed by Brandon
parent 56b9906e2e
commit 70c278c810
7 changed files with 33 additions and 95 deletions

View File

@@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
+from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
@@ -188,17 +189,13 @@ class FluxModelLoaderInvocation(BaseInvocation):
         vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
-        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
-            max_seq_len=flux_conf["max_seq_len"],
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
         )

     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
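
The hunk above swaps a YAML read for an in-memory lookup keyed by the config's config_path. A minimal standalone sketch of the new resolution path, with values mirroring the max_seq_lengths dict added to invokeai/backend/flux/util.py below (the config_path value here is hypothetical):

from typing import Dict, Literal

# Mirror of the max_seq_lengths constant added in this commit.
max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-schnell": 256,
}

# Before: flux_conf = yaml.safe_load(stream); flux_conf["max_seq_len"]
# After: a plain dict lookup keyed by the config's config_path, no file I/O.
config_path = "flux-dev"  # hypothetical value of transformer_config.config_path
assert max_seq_lengths[config_path] == 512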

View File

@@ -2,6 +2,7 @@
 import os
 from dataclasses import dataclass
+from typing import Dict, Literal

 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@@ -18,6 +19,25 @@ class ModelSpec:
     repo_ae: str | None


+max_seq_lengths: Dict[str, Literal[256, 512]] = {
+    "flux-dev": 512,
+    "flux-schnell": 256,
+}
+
+
+ae_params = AutoEncoderParams(
+    resolution=256,
+    in_channels=3,
+    ch=128,
+    out_ch=3,
+    ch_mult=[1, 2, 4, 4],
+    num_res_blocks=2,
+    z_channels=16,
+    scale_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+
 configs = {
     "flux-dev": ModelSpec(
         repo_id="black-forest-labs/FLUX.1-dev",
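
Taken together, the two module-level constants added above, plus the preexisting configs dict, replace everything the loaders used to parse out of the legacy YAML files. A sketch of how the rest of this commit consumes them (the import path and the configs[...].params access are as shown in the hunks; the helper function itself is hypothetical and requires InvokeAI at this commit):

from invokeai.backend.flux.util import ae_params, configs, max_seq_lengths

def resolve_flux_settings(config_path: str):
    """Hypothetical helper, not part of the commit: gather everything the
    loaders previously parsed out of a legacy YAML file."""
    return (
        configs[config_path].params,   # FluxParams for the transformer
        ae_params,                     # shared AutoEncoderParams for the VAE
        max_seq_lengths[config_path],  # T5 max sequence length (256 or 512)
    )

params, vae_params, max_seq_len = resolve_flux_settings("flux-schnell")
assert max_seq_len == 256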

View File

@@ -1,19 +1,18 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""

-from dataclasses import fields
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional

 import accelerate
 import torch
-import yaml
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.flux.model import Flux, FluxParams
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs, ae_params
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -58,17 +57,9 @@ class FluxVAELoader(ModelLoader):
         if not isinstance(config, VAECheckpointConfig):
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
-
-        dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = AutoEncoderParams(**filtered_data)

         with SilenceWarnings():
-            model = AutoEncoder(params)
+            model = AutoEncoder(ae_params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
@@ -182,14 +173,10 @@ class FluxCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -198,16 +185,12 @@
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
-            model = Flux(params)
+            model = Flux(configs[config.config_path].params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
         return model
@@ -224,14 +207,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -240,7 +219,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
         if not bnb_available:
@@ -248,13 +226,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(params)
+                model = Flux(configs[config.config_path].params)
                 model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
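
What disappears in these hunks is the load-filter-splat idiom: read a YAML mapping, intersect its keys with the dataclass's fields, then splat the survivors into the constructor. A self-contained sketch of that old pattern next to its replacement (the Params stand-in is illustrative; the depth and num_heads values come from the deleted flux YAMLs at the bottom of this diff):

from dataclasses import dataclass, fields

@dataclass
class Params:  # hypothetical stand-in for FluxParams / AutoEncoderParams
    depth: int
    num_heads: int

# The deleted idiom: load a YAML mapping, keep only keys that are dataclass
# fields, then splat the survivors into the constructor.
yaml_params = {"depth": 19, "num_heads": 24, "repo_id": "dropped by the filter"}
field_names = {f.name for f in fields(Params)}
filtered = {k: v for k, v in yaml_params.items() if k in field_names}
old_style = Params(**filtered)

# The replacement: parameters are prebuilt Python constants, so the loaders
# just index a dict, with no file I/O and no filtering.
new_style = {"flux-dev": Params(depth=19, num_heads=24)}["flux-dev"]
assert old_style == new_style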

View File

@@ -329,9 +329,9 @@ class ModelProbe(object):
                 checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
                 state_dict = checkpoint.get("state_dict") or checkpoint
                 if "guidance_in.out_layer.weight" in state_dict:
-                    config_file = "flux/flux1-dev.yaml"
+                    config_file = "flux-dev"
                 else:
-                    config_file = "flux/flux1-schnell.yaml"
+                    config_file = "flux-schnell"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
                 if isinstance(config_file, dict):  # need another tier for sd-2.x models
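
Only the keys change here; the detection logic stays the same. FLUX.1-dev is guidance-distilled (guidance_embed: True in its deleted YAML below, False for schnell), so its checkpoint carries guidance-embedding weights that a schnell checkpoint lacks. A sketch of the check on illustrative state dicts:

def flux_variant(state_dict: dict) -> str:
    """Sketch of the probe's dev/schnell test; state-dict contents are illustrative."""
    if "guidance_in.out_layer.weight" in state_dict:
        return "flux-dev"  # was "flux/flux1-dev.yaml"
    return "flux-schnell"  # was "flux/flux1-schnell.yaml"

assert flux_variant({"guidance_in.out_layer.weight": 0}) == "flux-dev"
assert flux_variant({"txt_in.weight": 0}) == "flux-schnell"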

View File

@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-dev"
-repo_ae: "ae.safetensors"
-max_seq_len: 512
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: True
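
Nothing in this deleted file is lost: max_seq_len moved into max_seq_lengths above, and the params block corresponds to the FluxParams carried by configs["flux-dev"] in util.py. A stand-in sketch of that mapping, with field names taken from the YAML keys (the real FluxParams lives in invokeai.backend.flux.model):

from dataclasses import dataclass
from typing import List

@dataclass
class FluxParams:  # illustrative stand-in for invokeai.backend.flux.model.FluxParams
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: List[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool

flux_dev_params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,  # the only params difference vs. flux-schnell (False)
)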

View File

@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_ae: "ae.safetensors"
-max_seq_len: 256
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: False

View File

@@ -1,16 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_path: "ae.safetensors"
-params:
-  resolution: 256
-  in_channels: 3
-  ch: 128
-  out_ch: 3
-  ch_mult:
-    - 1
-    - 2
-    - 4
-    - 4
-  num_res_blocks: 2
-  z_channels: 16
-  scale_factor: 0.3611
-  shift_factor: 0.1159
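
As with the transformer configs, nothing here is lost: these are exactly the values of the ae_params constant added to util.py above. Since both dev and schnell point at the same ae.safetensors autoencoder, a single module-level instance replaces this file outright.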