Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 8b0b496c2d (parent ada483f65e)
More flux loader cleanup
@@ -64,7 +64,7 @@ class FluxVAELoader(ModelLoader):
         with SilenceWarnings():
             model = AutoEncoder(params)
             sd = load_file(model_path)
-            model.load_state_dict(sd, strict=False, assign=True)
+            model.load_state_dict(sd, assign=True)
 
         return model
 
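The change above drops `strict=False`, so `load_state_dict` now runs with its default `strict=True` while keeping `assign=True`. A minimal sketch of what those two flags do, using a toy module rather than InvokeAI's `AutoEncoder` (assumes PyTorch 2.1+, where `assign` was introduced):

```python
import torch

class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

model = Tiny()
sd = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(4)}
# assign=True adopts the checkpoint tensors directly instead of copying
# them into the existing parameters; strict=True (the default) validates keys.
model.load_state_dict(sd, assign=True)

# With strict left at its default, a key mismatch fails loudly instead of
# silently loading a partial model:
try:
    model.load_state_dict({"proj.weight": torch.zeros(4, 4)}, assign=True)
except RuntimeError as err:
    print(err)  # reports the missing "proj.bias" key
```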
@@ -83,11 +83,11 @@ class ClipCheckpointModel(ModelLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer:
-                return CLIPTokenizer.from_pretrained(config.path, max_length=77)
+                return CLIPTokenizer.from_pretrained(config.path)
             case SubModelType.TextEncoder:
                 return CLIPTextModel.from_pretrained(config.path)
 
-        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
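Two cleanups land in this hunk: the explicit `max_length=77` is dropped (presumably redundant, since a CLIP tokenizer's bundled config already carries `model_max_length=77`), and the fallthrough error now reports what it actually received. A self-contained sketch of the match/raise pattern, with a stand-in enum rather than InvokeAI's `SubModelType` (names here are illustrative, not the real API):

```python
from enum import Enum
from typing import Optional

class SubModelType(str, Enum):  # stand-in for the real enum
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    VAE = "vae"

def load_submodel(submodel_type: Optional[SubModelType]) -> str:
    match submodel_type:
        case SubModelType.Tokenizer:
            return "loaded tokenizer"
        case SubModelType.TextEncoder:
            return "loaded text encoder"
    # The guard handles submodel_type=None without an AttributeError.
    raise ValueError(
        "Only Tokenizer and TextEncoder submodels are currently supported. "
        f"Received: {submodel_type.value if submodel_type else 'None'}"
    )

print(load_submodel(SubModelType.Tokenizer))  # loaded tokenizer
try:
    load_submodel(SubModelType.VAE)
except ValueError as err:
    print(err)  # ... Received: vae
```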
@@ -108,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
             case SubModelType.TextEncoder2:
                 return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
 
-        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@@ -131,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
                     Path(config.path) / "text_encoder_2"
                 )  # TODO: Fix hf subfolder install
 
-        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
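The `# TODO: Fix hf subfolder install` comment is left untouched by this commit. One plausible reading of it (an assumption on my part, not the committed fix) is that the loader should use transformers' `subfolder` argument rather than joining `text_encoder_2` onto the path by hand:

```python
from transformers import T5EncoderModel

# Hypothetical alternative to Path(config.path) / "text_encoder_2";
# "path/to/model" stands in for config.path.
model = T5EncoderModel.from_pretrained("path/to/model", subfolder="text_encoder_2")
```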
@@ -154,7 +154,7 @@ class FluxCheckpointModel(ModelLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)
 
-        raise ValueError("Only Transformer submodels are currently supported.")
+        raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
 
     def _load_from_singlefile(
         self,
@@ -162,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
         flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
-        params = None
         model_path = Path(config.path)
         dataclass_fields = {f.name for f in fields(FluxParams)}
         filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
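The deleted `params = None` was dead code: `params` is rebuilt from the filtered config before it is ever read (the next hunk constructs `Flux(params)`). The surrounding field-filtering idiom is worth spelling out; a sketch with a toy dataclass standing in for `FluxParams`:

```python
from dataclasses import dataclass, fields

@dataclass
class ToyParams:  # stand-in for FluxParams
    in_channels: int
    hidden_size: int

flux_conf = {"params": {"in_channels": 64, "hidden_size": 3072, "legacy_key": 1}}

# Keep only the keys the dataclass actually declares, so stray config
# entries don't make the constructor raise TypeError.
dataclass_fields = {f.name for f in fields(ToyParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = ToyParams(**filtered_data)
print(params)  # ToyParams(in_channels=64, hidden_size=3072)
```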
@@ -171,7 +170,7 @@ class FluxCheckpointModel(ModelLoader):
         with SilenceWarnings():
             model = Flux(params)
             sd = load_file(model_path)
-            model.load_state_dict(sd, strict=False, assign=True)
+            model.load_state_dict(sd, assign=True)
             return model
 
 
@@ -195,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)
 
-        raise ValueError("Only Transformer submodels are currently supported.")
+        raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
 
     def _load_from_singlefile(
         self,
@@ -203,7 +202,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
         flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
-        params = None
         model_path = Path(config.path)
         dataclass_fields = {f.name for f in fields(FluxParams)}
         filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
@@ -214,5 +212,5 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
             model = Flux(params)
             model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
-            model.load_state_dict(sd, strict=False, assign=True)
+            model.load_state_dict(sd, assign=True)
             return model
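All three single-file loaders in this diff now call `load_state_dict(sd, assign=True)` without `strict=False`. One common reason to prefer `assign=True` (whether or not InvokeAI's quantized path does exactly this) is that it lets you build a module without allocating real weights and then adopt the checkpoint tensors directly. A minimal sketch, assuming PyTorch 2.1+ and a toy module in place of `Flux`:

```python
import torch

# Build the module on the meta device: parameters exist but own no storage.
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

sd = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}

# assign=True swaps the meta parameters for the checkpoint tensors instead
# of copying into them (copying into meta tensors would fail).
model.load_state_dict(sd, assign=True)
print(model.weight.device)  # cpu
```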
@@ -224,6 +224,7 @@ class ModelProbe(object):
 
         for key in [str(k) for k in ckpt.keys()]:
             if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
+                # Keys starting with double_blocks are associated with Flux models
                 return ModelType.Main
             elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                 return ModelType.VAE
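The only change in this hunk is the new comment documenting why `double_blocks.` sits in the Main-model prefix list: Flux transformer checkpoints name their layers `double_blocks.N....` A sketch of the probing heuristic in isolation (function name and return strings are illustrative, not the ModelProbe API):

```python
def classify_checkpoint(ckpt: dict) -> str:
    """Guess a checkpoint's model type from its state-dict key prefixes."""
    for key in [str(k) for k in ckpt.keys()]:
        if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
            # "double_blocks." keys are associated with Flux models
            return "main"
        elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
            return "vae"
    raise ValueError("Unrecognized state dict layout")

print(classify_checkpoint({"double_blocks.0.img_attn.qkv.weight": 0}))  # main
print(classify_checkpoint({"decoder.conv_in.weight": 0}))               # vae
```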