Various styling and exception type updates

This commit is contained in:
Brandon Rising 2024-08-21 11:59:04 -04:00 committed by Brandon
parent 87b7a2e39b
commit df9445c351
3 changed files with 32 additions and 48 deletions

View File

@ -183,7 +183,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
model_key = self.model.key model_key = self.model.key
if not context.models.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer) transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer) tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2) tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
@ -203,10 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
config_path = legacy_config_path.as_posix() config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream: with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream) flux_conf = yaml.safe_load(stream)
except:
raise
return FluxModelLoaderOutput( return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer), transformer=TransformerField(transformer=transformer),

View File

@ -484,7 +484,7 @@ class ModelsInterface(InvocationContextInterface):
ModelInstallJob object defining the install job to be used in tracking the job ModelInstallJob object defining the install job to be used in tracking the job
""" """
if not model_path.exists(): if not model_path.exists():
raise Exception("Models provided to import_local_model must already exist on disk") raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace) return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)
def load_local_model( def load_local_model(

View File

@ -49,29 +49,24 @@ class FluxVAELoader(ModelLoader):
config: AnyModelConfig, config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if isinstance(config, VAECheckpointConfig): if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path) model_path = Path(config.path)
load_class = AutoEncoder
legacy_config_path = app_config.legacy_conf_path / config.config_path legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix() config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream: with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream) flux_conf = yaml.safe_load(stream)
except:
raise
dataclass_fields = {f.name for f in fields(AutoEncoderParams)} dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields} filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data) params = AutoEncoderParams(**filtered_data)
with SilenceWarnings(): with SilenceWarnings():
model = load_class(params) model = AutoEncoder(params)
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True) model.load_state_dict(sd, strict=False, assign=True)
return model return model
else:
return super()._load_model(config, submodel_type)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
@ -84,7 +79,7 @@ class ClipCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig): if not isinstance(config, CLIPEmbedDiffusersConfig):
raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.") raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type: match submodel_type:
case SubModelType.Tokenizer: case SubModelType.Tokenizer:
@ -92,7 +87,7 @@ class ClipCheckpointModel(ModelLoader):
case SubModelType.TextEncoder: case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(config.path) return CLIPTextModel.from_pretrained(config.path)
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@ -105,7 +100,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if not isinstance(config, T5Encoder8bConfig): if not isinstance(config, T5Encoder8bConfig):
raise Exception("Only T5Encoder8bConfig models are currently supported here.") raise ValueError("Only T5Encoder8bConfig models are currently supported here.")
match submodel_type: match submodel_type:
case SubModelType.Tokenizer2: case SubModelType.Tokenizer2:
@ -113,7 +108,7 @@ class T5Encoder8bCheckpointModel(ModelLoader):
case SubModelType.TextEncoder2: case SubModelType.TextEncoder2:
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2") return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@ -126,7 +121,7 @@ class T5EncoderCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if not isinstance(config, T5EncoderConfig): if not isinstance(config, T5EncoderConfig):
raise Exception("Only T5EncoderConfig models are currently supported here.") raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type: match submodel_type:
case SubModelType.Tokenizer2: case SubModelType.Tokenizer2:
@ -136,7 +131,7 @@ class T5EncoderCheckpointModel(ModelLoader):
Path(config.path) / "text_encoder_2" Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install ) # TODO: Fix hf subfolder install
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ -149,20 +144,17 @@ class FluxCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if not isinstance(config, CheckpointConfigBase): if not isinstance(config, CheckpointConfigBase):
raise Exception("Only CheckpointConfigBase models are currently supported here.") raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix() config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream: with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream) flux_conf = yaml.safe_load(stream)
except:
raise
match submodel_type: match submodel_type:
case SubModelType.Transformer: case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf) return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Transformer submodels are currently supported.") raise ValueError("Only Transformer submodels are currently supported.")
def _load_from_singlefile( def _load_from_singlefile(
self, self,
@ -170,7 +162,6 @@ class FluxCheckpointModel(ModelLoader):
flux_conf: Any, flux_conf: Any,
) -> AnyModel: ) -> AnyModel:
assert isinstance(config, MainCheckpointConfig) assert isinstance(config, MainCheckpointConfig)
load_class = Flux
params = None params = None
model_path = Path(config.path) model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)} dataclass_fields = {f.name for f in fields(FluxParams)}
@ -178,7 +169,7 @@ class FluxCheckpointModel(ModelLoader):
params = FluxParams(**filtered_data) params = FluxParams(**filtered_data)
with SilenceWarnings(): with SilenceWarnings():
model = load_class(params) model = Flux(params)
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True) model.load_state_dict(sd, strict=False, assign=True)
return model return model
@ -194,20 +185,17 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> AnyModel: ) -> AnyModel:
if not isinstance(config, CheckpointConfigBase): if not isinstance(config, CheckpointConfigBase):
raise Exception("Only CheckpointConfigBase models are currently supported here.") raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix() config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream: with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream) flux_conf = yaml.safe_load(stream)
except:
raise
match submodel_type: match submodel_type:
case SubModelType.Transformer: case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf) return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Transformer submodels are currently supported.") raise ValueError("Only Transformer submodels are currently supported.")
def _load_from_singlefile( def _load_from_singlefile(
self, self,
@ -215,7 +203,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
flux_conf: Any, flux_conf: Any,
) -> AnyModel: ) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig) assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
load_class = Flux
params = None params = None
model_path = Path(config.path) model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)} dataclass_fields = {f.name for f in fields(FluxParams)}
@ -224,7 +211,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings(): with SilenceWarnings():
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = load_class(params) model = Flux(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True) model.load_state_dict(sd, strict=False, assign=True)