mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove automatic install of models during flux model loader, remove no longer used import function on context
This commit is contained in:
parent
a0a259eef1
commit
f130ddec7c
@ -141,21 +141,6 @@ class ModelIdentifierInvocation(BaseInvocation):
|
|||||||
return ModelIdentifierOutput(model=self.model)
|
return ModelIdentifierOutput(model=self.model)
|
||||||
|
|
||||||
|
|
||||||
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
|
||||||
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
|
||||||
"base": {
|
|
||||||
"repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
|
|
||||||
"name": "t5_base_encoder",
|
|
||||||
"format": ModelFormat.T5Encoder,
|
|
||||||
},
|
|
||||||
"8b_quantized": {
|
|
||||||
"repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
|
|
||||||
"name": "t5_8b_quantized_encoder",
|
|
||||||
"format": ModelFormat.T5Encoder8b,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("flux_model_loader_output")
|
@invocation_output("flux_model_loader_output")
|
||||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Flux base model loader output"""
|
"""Flux base model loader output"""
|
||||||
@ -196,15 +181,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||||
vae = self._install_model(
|
vae = self._get_model(context, SubModelType.VAE)
|
||||||
context,
|
|
||||||
SubModelType.VAE,
|
|
||||||
"FLUX.1-schnell_ae",
|
|
||||||
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
|
||||||
ModelFormat.Checkpoint,
|
|
||||||
ModelType.VAE,
|
|
||||||
BaseModelType.Flux,
|
|
||||||
)
|
|
||||||
transformer_config = context.models.get_config(transformer)
|
transformer_config = context.models.get_config(transformer)
|
||||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||||
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
|
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
|
||||||
@ -224,36 +201,38 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
match submodel:
|
match submodel:
|
||||||
case SubModelType.Transformer:
|
case SubModelType.Transformer:
|
||||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||||
|
case SubModelType.VAE:
|
||||||
|
return self._pull_model_from_mm(
|
||||||
|
context,
|
||||||
|
SubModelType.VAE,
|
||||||
|
"FLUX.1-schnell_ae",
|
||||||
|
ModelType.VAE,
|
||||||
|
BaseModelType.Flux,
|
||||||
|
)
|
||||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||||
return self._install_model(
|
return self._pull_model_from_mm(
|
||||||
context,
|
context,
|
||||||
submodel,
|
submodel,
|
||||||
"clip-vit-large-patch14",
|
"clip-vit-large-patch14",
|
||||||
"openai/clip-vit-large-patch14",
|
|
||||||
ModelFormat.Diffusers,
|
|
||||||
ModelType.CLIPEmbed,
|
ModelType.CLIPEmbed,
|
||||||
BaseModelType.Any,
|
BaseModelType.Any,
|
||||||
)
|
)
|
||||||
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||||
return self._install_model(
|
return self._pull_model_from_mm(
|
||||||
context,
|
context,
|
||||||
submodel,
|
submodel,
|
||||||
self.t5_encoder.name,
|
self.t5_encoder.name,
|
||||||
"",
|
|
||||||
ModelFormat.T5Encoder,
|
|
||||||
ModelType.T5Encoder,
|
ModelType.T5Encoder,
|
||||||
BaseModelType.Any,
|
BaseModelType.Any,
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||||
|
|
||||||
def _install_model(
|
def _pull_model_from_mm(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
submodel: SubModelType,
|
submodel: SubModelType,
|
||||||
name: str,
|
name: str,
|
||||||
repo_id: str,
|
|
||||||
format: ModelFormat,
|
|
||||||
type: ModelType,
|
type: ModelType,
|
||||||
base: BaseModelType,
|
base: BaseModelType,
|
||||||
):
|
):
|
||||||
@ -262,16 +241,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
|||||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||||
else:
|
else:
|
||||||
model_path = context.models.download_and_cache_model(repo_id)
|
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||||
config = ModelRecordChanges(name=name, base=base, type=type, format=format)
|
|
||||||
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
|
|
||||||
while not model_install_job.in_terminal_state:
|
|
||||||
sleep(0.01)
|
|
||||||
if not model_install_job.config_out:
|
|
||||||
raise Exception(f"Failed to install {name}")
|
|
||||||
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
|
|
||||||
update={"submodel_type": submodel}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -464,29 +464,6 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
"""
|
"""
|
||||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||||
|
|
||||||
def import_local_model(
|
|
||||||
self,
|
|
||||||
model_path: Path,
|
|
||||||
config: Optional[ModelRecordChanges] = None,
|
|
||||||
inplace: Optional[bool] = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Import the model file located at the given local file path and return its ModelInstallJob.
|
|
||||||
|
|
||||||
This can be used to single-file models or directories.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: A pathlib.Path object pointing to a model file or directory
|
|
||||||
config: Optional ModelRecordChanges to define manual probe overrides
|
|
||||||
inplace: Optional boolean to declare whether or not to install the model in the models dir
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelInstallJob object defining the install job to be used in tracking the job
|
|
||||||
"""
|
|
||||||
if not model_path.exists():
|
|
||||||
raise ValueError(f"Models provided to import_local_model must already exist on disk at {model_path.as_posix()}")
|
|
||||||
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)
|
|
||||||
|
|
||||||
def load_local_model(
|
def load_local_model(
|
||||||
self,
|
self,
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
|
Loading…
Reference in New Issue
Block a user