mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Various styling and exception type updates

This commit is contained in:
parent 87b7a2e39b
commit df9445c351
@@ -183,7 +183,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         if not context.models.exists(model_key):
-            raise Exception(f"Unknown model: {model_key}")
+            raise ValueError(f"Unknown model: {model_key}")
         transformer = self._get_model(context, SubModelType.Transformer)
         tokenizer = self._get_model(context, SubModelType.Tokenizer)
         tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
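Note: narrowing the raise from Exception to ValueError lets call sites handle a bad model key without a blanket `except Exception`. A minimal sketch of a hypothetical call site (`loader.invoke` and `show_user_error` are illustrative names, not part of this diff):

    # Hypothetical call site: only the anticipated failure is caught;
    # unrelated bugs still propagate instead of being swallowed.
    try:
        output = loader.invoke(context)
    except ValueError as err:
        show_user_error(str(err))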
@@ -203,10 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
-            try:
-                flux_conf = yaml.safe_load(stream)
-            except:
-                raise
+            flux_conf = yaml.safe_load(stream)

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
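Note: the deleted `try`/`except: raise` was a no-op wrapper; a bare `except:` that immediately re-raises behaves exactly like having no handler at all, so the load reduces to a single call. If a handler were ever wanted here, a sketch that adds context instead of re-raising blindly might look like this (the wrapped error type is an assumption, not part of this commit):

    import yaml

    with open(config_path, "r") as stream:
        try:
            flux_conf = yaml.safe_load(stream)
        except yaml.YAMLError as e:
            # Attach the file path so a malformed legacy config is traceable.
            raise ValueError(f"Invalid legacy config at {config_path}") from e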
@@ -484,7 +484,7 @@ class ModelsInterface(InvocationContextInterface):
             ModelInstallJob object defining the install job to be used in tracking the job
         """
         if not model_path.exists():
-            raise Exception("Models provided to import_local_model must already exist on disk")
+            raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
         return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)

     def load_local_model(
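Note: the message now carries the offending path, which makes the failure actionable without a debugger. A self-contained sketch of the new behavior (the function below is a trimmed stand-in for the method above, not the real API):

    from pathlib import Path

    def import_local_model(model_path: Path):  # trimmed stand-in
        if not model_path.exists():
            raise ValueError(
                f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}"
            )

    try:
        import_local_model(Path("/models/does-not-exist.safetensors"))
    except ValueError as err:
        print(err)  # the message now names the offending path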
@@ -49,29 +49,24 @@ class FluxVAELoader(ModelLoader):
         config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
-        if isinstance(config, VAECheckpointConfig):
-            model_path = Path(config.path)
-            load_class = AutoEncoder
-            legacy_config_path = app_config.legacy_conf_path / config.config_path
-            config_path = legacy_config_path.as_posix()
-            with open(config_path, "r") as stream:
-                try:
-                    flux_conf = yaml.safe_load(stream)
-                except:
-                    raise
-
-            dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-            filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-            params = AutoEncoderParams(**filtered_data)
-
-            with SilenceWarnings():
-                model = load_class(params)
-                sd = load_file(model_path)
-                model.load_state_dict(sd, strict=False, assign=True)
-
-            return model
-        else:
-            return super()._load_model(config, submodel_type)
+        if not isinstance(config, VAECheckpointConfig):
+            raise ValueError("Only VAECheckpointConfig models are currently supported here.")
+        model_path = Path(config.path)
+        legacy_config_path = app_config.legacy_conf_path / config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            flux_conf = yaml.safe_load(stream)
+
+        dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
+        params = AutoEncoderParams(**filtered_data)
+
+        with SilenceWarnings():
+            model = AutoEncoder(params)
+            sd = load_file(model_path)
+            model.load_state_dict(sd, strict=False, assign=True)
+
+        return model


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
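Note: the loader keeps the existing idiom of filtering the YAML `params` block down to the fields that `AutoEncoderParams` actually declares, so stray keys in a legacy config don't break construction. A self-contained sketch of that idiom (the dataclass below is a stand-in with made-up fields, not the real AutoEncoderParams):

    from dataclasses import dataclass, fields

    @dataclass
    class AutoEncoderParams:  # stand-in with made-up fields
        resolution: int
        in_channels: int

    flux_conf = {"params": {"resolution": 256, "in_channels": 3, "extra_key": "ignored"}}
    dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
    filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
    params = AutoEncoderParams(**filtered_data)  # extra_key is dropped, not an error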
@@ -84,7 +79,7 @@ class ClipCheckpointModel(ModelLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CLIPEmbedDiffusersConfig):
-            raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
+            raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")

         match submodel_type:
             case SubModelType.Tokenizer:
@@ -92,7 +87,7 @@ class ClipCheckpointModel(ModelLoader):
             case SubModelType.TextEncoder:
                 return CLIPTextModel.from_pretrained(config.path)

-        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@@ -105,7 +100,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, T5Encoder8bConfig):
-            raise Exception("Only T5Encoder8bConfig models are currently supported here.")
+            raise ValueError("Only T5Encoder8bConfig models are currently supported here.")

         match submodel_type:
             case SubModelType.Tokenizer2:
@@ -113,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
             case SubModelType.TextEncoder2:
                 return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")

-        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@@ -126,7 +121,7 @@ class T5EncoderCheckpointModel(ModelLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, T5EncoderConfig):
-            raise Exception("Only T5EncoderConfig models are currently supported here.")
+            raise ValueError("Only T5EncoderConfig models are currently supported here.")

         match submodel_type:
             case SubModelType.Tokenizer2:
@@ -136,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
                     Path(config.path) / "text_encoder_2"
                 )  # TODO: Fix hf subfolder install

-        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
+        raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")


 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -149,20 +144,17 @@ class FluxCheckpointModel(ModelLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only CheckpointConfigBase models are currently supported here.")
+            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
-            try:
-                flux_conf = yaml.safe_load(stream)
-            except:
-                raise
+            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)

-        raise Exception("Only Transformer submodels are currently supported.")
+        raise ValueError("Only Transformer submodels are currently supported.")

     def _load_from_singlefile(
         self,
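Note: the `match` above has no `case _` arm, so an unsupported submodel falls through to the `raise` after the block; the narrower ValueError keeps that expected fall-through distinguishable from real bugs. A self-contained demo of the same control flow (Python 3.10+, toy names):

    # No wildcard arm: an unmatched value simply falls through the match
    # block and reaches the raise below it.
    def dispatch(submodel: str):
        match submodel:
            case "transformer":
                return "loaded"
        raise ValueError("Only Transformer submodels are currently supported.")

    dispatch("transformer")   # -> "loaded"
    # dispatch("vae")         # -> ValueError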
@@ -170,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
         flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
-        load_class = Flux
         params = None
         model_path = Path(config.path)
         dataclass_fields = {f.name for f in fields(FluxParams)}
@@ -178,7 +169,7 @@ class FluxCheckpointModel(ModelLoader):
             params = FluxParams(**filtered_data)

         with SilenceWarnings():
-            model = load_class(params)
+            model = Flux(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model
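Note: dropping the single-use `load_class` indirection in favor of `Flux(params)` doesn't change the load path: safetensors' `load_file` returns a plain state dict, and `load_state_dict(..., assign=True)` (PyTorch 2.1+) binds those tensors directly instead of copying them into pre-allocated parameters. A minimal runnable sketch of the same pattern on a toy module:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    sd = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
    # assign=True rebinds the parameters to these exact tensors;
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(sd, strict=False, assign=True)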
@@ -194,20 +185,17 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only CheckpointConfigBase models are currently supported here.")
+            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
-            try:
-                flux_conf = yaml.safe_load(stream)
-            except:
-                raise
+            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)

-        raise Exception("Only Transformer submodels are currently supported.")
+        raise ValueError("Only Transformer submodels are currently supported.")

     def _load_from_singlefile(
         self,
@@ -215,7 +203,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
         flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
-        load_class = Flux
         params = None
         model_path = Path(config.path)
         dataclass_fields = {f.name for f in fields(FluxParams)}
@@ -224,7 +211,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):

         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = load_class(params)
+                model = Flux(params)
             model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
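Note: the order in this hunk matters: `accelerate.init_empty_weights()` builds the Flux module graph on the meta device (no real allocation), `quantize_model_nf4` then swaps in NF4 layers, and only the final `load_state_dict(..., assign=True)` materializes real tensors from the checkpoint. Condensed from the hunk above as a sketch (names as in the diff; `params` and `model_path` come from the surrounding method):

    with SilenceWarnings():
        with accelerate.init_empty_weights():
            model = Flux(params)  # meta tensors only, nothing allocated yet
        model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
        sd = load_file(model_path)
        model.load_state_dict(sd, strict=False, assign=True)  # materialize weights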