More flux loader cleanup

This commit is contained in:
Brandon Rising 2024-08-21 12:37:25 -04:00
parent ada483f65e
commit 8b0b496c2d
2 changed files with 10 additions and 11 deletions

View File

@ -64,7 +64,7 @@ class FluxVAELoader(ModelLoader):
with SilenceWarnings():
model = AutoEncoder(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
model.load_state_dict(sd, assign=True)
return model
@ -83,11 +83,11 @@ class ClipCheckpointModel(ModelLoader):
match submodel_type:
case SubModelType.Tokenizer:
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
return CLIPTokenizer.from_pretrained(config.path)
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(config.path)
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@ -108,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
case SubModelType.TextEncoder2:
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@ -131,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ -154,7 +154,7 @@ class FluxCheckpointModel(ModelLoader):
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise ValueError("Only Transformer submodels are currently supported.")
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
def _load_from_singlefile(
self,
@ -162,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
@ -171,7 +170,7 @@ class FluxCheckpointModel(ModelLoader):
with SilenceWarnings():
model = Flux(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
model.load_state_dict(sd, assign=True)
return model
@ -195,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise ValueError("Only Transformer submodels are currently supported.")
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
def _load_from_singlefile(
self,
@ -203,7 +202,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
@ -214,5 +212,5 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
model = Flux(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
model.load_state_dict(sd, assign=True)
return model

View File

@ -224,6 +224,7 @@ class ModelProbe(object):
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE