Various styling and exception type updates

This commit is contained in:
Brandon Rising 2024-08-21 11:59:04 -04:00 committed by Brandon
parent 87b7a2e39b
commit df9445c351
3 changed files with 32 additions and 48 deletions

View File

@ -183,7 +183,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
model_key = self.model.key
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
@ -203,10 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
flux_conf = yaml.safe_load(stream)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),

View File

@ -484,7 +484,7 @@ class ModelsInterface(InvocationContextInterface):
ModelInstallJob object defining the install job to be used in tracking the job
"""
if not model_path.exists():
raise Exception("Models provided to import_local_model must already exist on disk")
raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)
def load_local_model(

View File

@ -49,29 +49,24 @@ class FluxVAELoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
model_path = Path(config.path)
load_class = AutoEncoder
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)
with SilenceWarnings():
model = load_class(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
with SilenceWarnings():
model = AutoEncoder(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model
else:
return super()._load_model(config, submodel_type)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
@ -84,7 +79,7 @@ class ClipCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
@ -92,7 +87,7 @@ class ClipCheckpointModel(ModelLoader):
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(config.path)
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@ -105,7 +100,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5Encoder8bConfig):
raise Exception("Only T5Encoder8bConfig models are currently supported here.")
raise ValueError("Only T5Encoder8bConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
@ -113,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
case SubModelType.TextEncoder2:
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@ -126,7 +121,7 @@ class T5EncoderCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise Exception("Only T5EncoderConfig models are currently supported here.")
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
@ -136,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ -149,20 +144,17 @@ class FluxCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only CheckpointConfigBase models are currently supported here.")
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
flux_conf = yaml.safe_load(stream)
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Transformer submodels are currently supported.")
raise ValueError("Only Transformer submodels are currently supported.")
def _load_from_singlefile(
self,
@ -170,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
load_class = Flux
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
@ -178,7 +169,7 @@ class FluxCheckpointModel(ModelLoader):
params = FluxParams(**filtered_data)
with SilenceWarnings():
model = load_class(params)
model = Flux(params)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)
return model
@ -194,20 +185,17 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only CheckpointConfigBase models are currently supported here.")
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
flux_conf = yaml.safe_load(stream)
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Transformer submodels are currently supported.")
raise ValueError("Only Transformer submodels are currently supported.")
def _load_from_singlefile(
self,
@ -215,7 +203,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
load_class = Flux
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
@ -224,7 +211,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings():
with accelerate.init_empty_weights():
model = load_class(params)
model = Flux(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True)