mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
More flux loader cleanup
This commit is contained in:
parent
df9445c351
commit
72398350b4
@ -64,7 +64,7 @@ class FluxVAELoader(ModelLoader):
|
||||
with SilenceWarnings():
|
||||
model = AutoEncoder(params)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
|
||||
return model
|
||||
|
||||
@ -83,11 +83,11 @@ class ClipCheckpointModel(ModelLoader):
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer:
|
||||
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
|
||||
return CLIPTokenizer.from_pretrained(config.path)
|
||||
case SubModelType.TextEncoder:
|
||||
return CLIPTextModel.from_pretrained(config.path)
|
||||
|
||||
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
|
||||
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
|
||||
@ -108,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
|
||||
case SubModelType.TextEncoder2:
|
||||
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
|
||||
|
||||
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
|
||||
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||
@ -131,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
Path(config.path) / "text_encoder_2"
|
||||
) # TODO: Fix hf subfolder install
|
||||
|
||||
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
|
||||
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ -154,7 +154,7 @@ class FluxCheckpointModel(ModelLoader):
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config, flux_conf)
|
||||
|
||||
raise ValueError("Only Transformer submodels are currently supported.")
|
||||
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
@ -162,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
|
||||
flux_conf: Any,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
params = None
|
||||
model_path = Path(config.path)
|
||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||
@ -171,7 +170,7 @@ class FluxCheckpointModel(ModelLoader):
|
||||
with SilenceWarnings():
|
||||
model = Flux(params)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ -195,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config, flux_conf)
|
||||
|
||||
raise ValueError("Only Transformer submodels are currently supported.")
|
||||
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
@ -203,7 +202,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
flux_conf: Any,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||
params = None
|
||||
model_path = Path(config.path)
|
||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||
@ -214,5 +212,5 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
model = Flux(params)
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
@ -224,6 +224,7 @@ class ModelProbe(object):
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
|
Loading…
Reference in New Issue
Block a user