From 8b0b496c2d762a51ad8303b4e17a340ea532e0c2 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 21 Aug 2024 12:37:25 -0400 Subject: [PATCH] More flux loader cleanup --- .../model_manager/load/model_loaders/flux.py | 20 +++++++++---------- invokeai/backend/model_manager/probe.py | 1 + 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 58264ebc25..ebc3333eea 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -64,7 +64,7 @@ class FluxVAELoader(ModelLoader): with SilenceWarnings(): model = AutoEncoder(params) sd = load_file(model_path) - model.load_state_dict(sd, strict=False, assign=True) + model.load_state_dict(sd, assign=True) return model @@ -83,11 +83,11 @@ class ClipCheckpointModel(ModelLoader): match submodel_type: case SubModelType.Tokenizer: - return CLIPTokenizer.from_pretrained(config.path, max_length=77) + return CLIPTokenizer.from_pretrained(config.path) case SubModelType.TextEncoder: return CLIPTextModel.from_pretrained(config.path) - raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.") + raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}") @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b) @@ -108,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader): case SubModelType.TextEncoder2: return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2") - raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.") + raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}") @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @@ -131,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader): Path(config.path) / "text_encoder_2" ) # TODO: Fix hf subfolder install - raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.") + raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}") @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) @@ -154,7 +154,7 @@ class FluxCheckpointModel(ModelLoader): case SubModelType.Transformer: return self._load_from_singlefile(config, flux_conf) - raise ValueError("Only Transformer submodels are currently supported.") + raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}") def _load_from_singlefile( self, @@ -162,7 +162,6 @@ class FluxCheckpointModel(ModelLoader): flux_conf: Any, ) -> AnyModel: assert isinstance(config, MainCheckpointConfig) - params = None model_path = Path(config.path) dataclass_fields = {f.name for f in fields(FluxParams)} filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields} @@ -171,7 +170,7 @@ class FluxCheckpointModel(ModelLoader): with SilenceWarnings(): model = Flux(params) sd = load_file(model_path) - model.load_state_dict(sd, strict=False, assign=True) + model.load_state_dict(sd, assign=True) return model @@ -195,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): case SubModelType.Transformer: return self._load_from_singlefile(config, flux_conf) - raise ValueError("Only Transformer submodels are currently supported.") + raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}") def _load_from_singlefile( self, @@ -203,7 +202,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): flux_conf: Any, ) -> AnyModel: assert isinstance(config, MainBnbQuantized4bCheckpointConfig) - params = None model_path = Path(config.path) dataclass_fields = {f.name for f in fields(FluxParams)} filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields} @@ -214,5 +212,5 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): model = Flux(params) model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) sd = load_file(model_path) - model.load_state_dict(sd, strict=False, assign=True) + model.load_state_dict(sd, assign=True) return model diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a3364da769..778dd583e5 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -224,6 +224,7 @@ class ModelProbe(object): for key in [str(k) for k in ckpt.keys()]: if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")): + # Keys starting with double_blocks are associated with Flux models return ModelType.Main elif key.startswith(("encoder.conv_in", "decoder.conv_in")): return ModelType.VAE