Remove automatic install of models during flux model loader, remove no longer used import function on context

This commit is contained in:
Brandon Rising 2024-08-21 15:34:34 -04:00 committed by Brandon
parent a0a259eef1
commit f130ddec7c
2 changed files with 13 additions and 66 deletions

View File

@ -141,21 +141,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": {
"repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
"name": "t5_base_encoder",
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
"name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder8b,
},
}
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
@ -196,15 +181,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._install_model(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
ModelFormat.Checkpoint,
ModelType.VAE,
BaseModelType.Flux,
)
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
@ -224,36 +201,38 @@ class FluxModelLoaderInvocation(BaseInvocation):
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case SubModelType.VAE:
return self._pull_model_from_mm(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
ModelType.VAE,
BaseModelType.Flux,
)
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._install_model(
return self._pull_model_from_mm(
context,
submodel,
"clip-vit-large-patch14",
"openai/clip-vit-large-patch14",
ModelFormat.Diffusers,
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._install_model(
return self._pull_model_from_mm(
context,
submodel,
self.t5_encoder.name,
"",
ModelFormat.T5Encoder,
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _install_model(
def _pull_model_from_mm(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
repo_id: str,
format: ModelFormat,
type: ModelType,
base: BaseModelType,
):
@ -262,16 +241,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
model_path = context.models.download_and_cache_model(repo_id)
config = ModelRecordChanges(name=name, base=base, type=type, format=format)
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
while not model_install_job.in_terminal_state:
sleep(0.01)
if not model_install_job.config_out:
raise Exception(f"Failed to install {name}")
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
update={"submodel_type": submodel}
)
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
@invocation(

View File

@ -464,29 +464,6 @@ class ModelsInterface(InvocationContextInterface):
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
def import_local_model(
self,
model_path: Path,
config: Optional[ModelRecordChanges] = None,
inplace: Optional[bool] = False,
):
"""
Import the model file located at the given local file path and return its ModelInstallJob.
This can be used to single-file models or directories.
Args:
model_path: A pathlib.Path object pointing to a model file or directory
config: Optional ModelRecordChanges to define manual probe overrides
inplace: Optional boolean to declare whether or not to install the model in the models dir
Returns:
ModelInstallJob object defining the install job to be used in tracking the job
"""
if not model_path.exists():
raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)
def load_local_model(
self,
model_path: Path,