added main templates

Lincoln Stein
2023-08-20 21:34:43 -04:00
12 changed files with 223 additions and 21 deletions

View File

@@ -1,14 +1,16 @@
"""
Initialization file for invokeai.backend.model_manager.config
"""
from invokeai.backend.model_manager.config import ( # noqa F401
ModelConfigFactory,
ModelConfigBase,
InvalidModelConfigException,
from ..model_management.models.base import read_checkpoint_meta # noqa F401
from .config import ( # noqa F401
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from .model_install import ModelInstall # noqa F401
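With these re-exports in place, callers can pull the public names from the package root instead of reaching into the submodules. A minimal sketch (the printed enum values are whatever the commit defines, e.g. "controlnet" and "embedding" from the hunk below):

from invokeai.backend.model_manager import ModelInstall, ModelType

installer = ModelInstall()  # default app config and YAML-backed store
print([t.value for t in ModelType])  # includes "controlnet", "embedding", ...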

View File

@@ -50,7 +50,6 @@ class ModelType(str, Enum):
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
"""Submodel type."""

View File

@@ -0,0 +1,106 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Install/delete models.
Typical usage:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelInstall
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
config = InvokeAIAppConfig.get_config()
store = ModelConfigStoreSQL(config.db_path)
installer = ModelInstall(store=store, config=config)
# register config, don't move path
id: str = installer.register('/path/to/model')
# register config, and install model in `models`
id: str = installer.install('/path/to/model')
# unregister, don't delete
installer.forget(id)
# unregister and delete model from disk
installer.delete(id)
# scan directory recursively and install all new models found
ids: List[str] = installer.scan_directory('/path/to/directory')
# unregister any model whose path is no longer valid
ids: List[str] = installer.garbage_collect()
hash: str = installer.hash('/path/to/model') # should be same as id above
The following exceptions may be raised:
DuplicateModelException
UnknownModelTypeException
"""
from pathlib import Path
from typing import Optional, List
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .storage import ModelConfigStore
class ModelInstall(object):
"""Model installer class handles installation from a local path."""
_config: InvokeAIAppConfig
_logger: InvokeAILogger
_store: ModelConfigStore
def __init__(self,
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None
):
"""
Create ModelInstall object.
:param store: Optional ModelConfigStore. If None passed,
defaults to `configs/models.yaml`.
:param config: Optional InvokeAIAppConfig. If None passed,
uses the system-wide default app config.
:param logger: Optional InvokeAILogger. If None passed,
uses the system-wide default logger.
"""
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger()
if store is None:
from .storage import ModelConfigStoreYAML
store = ModelConfigStoreYAML(self._config.model_conf_path)
self._store = store
def register(self, model_path: Path) -> str:
"""Probe and register the model at model_path."""
pass
def install(self, model_path: Path) -> str:
"""Probe, register and Install the model in the models directory."""
pass
def forget(self, id: str) -> str:
"""Unregister the model identified by id."""
pass
def delete(self, id: str) -> str:
"""
Unregister and delete the model identified by id.
Note that this deletes the model unconditionally.
"""
pass
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""Scan directory for new models and register or install them."""
pass
def garbage_collect(self) -> List[str]:
"""Unregister any models whose paths are no longer valid."""
pass
def hash(self, model_path: Path) -> str:
"""Compute the fast hash of the model."""
pass
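All of the methods above are deliberate stubs at this point in the commit. As one illustration of the intended behavior, here is a hypothetical sketch of the fast hash; the BLAKE2 digest and the sorted file traversal are assumptions, not the shipped implementation:

import hashlib
from pathlib import Path

def fast_hash(model_path: Path) -> str:
    """Hypothetical fast hash: digest file bytes in a stable order."""
    h = hashlib.blake2b(digest_size=16)
    files = [model_path] if model_path.is_file() else sorted(
        p for p in model_path.rglob("*") if p.is_file()
    )
    for f in files:
        with open(f, "rb") as fh:
            for chunk in iter(lambda: fh.read(2**20), b""):  # 1 MiB chunks
                h.update(chunk)
    return h.hexdigest()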

View File

@@ -0,0 +1,55 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Module for probing a Stable Diffusion model and returning
its base type, model type, format and variant.
"""
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable
import torch
import safetensors.torch
from invokeai.backend.model_management.models.base import (
read_checkpoint_meta
)
import invokeai.configs.model_probe_templates as templates
from .config import (
ModelType,
BaseModelType,
ModelVariantType,
ModelFormat,
SchedulerPredictionType
)
@dataclass
class ModelProbeInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
format: ModelFormat
class ModelProbe(object):
"""
Class to probe a checkpoint, safetensors file, or diffusers folder.
"""
def __init__(self):
pass
@classmethod
def heuristic_probe(
cls,
model: Path,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
"""
Probe model located at path and return ModelProbeInfo object.
An optional callable may be passed in to determine the SchedulerPredictionType when it cannot be inferred from the checkpoint.
"""
pass
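heuristic_probe is likewise a stub here. Given the shape templates this commit adds (the suppressed JSON diffs below), one plausible strategy is to compare a checkpoint's tensor shapes against each stored template. A sketch under that assumption, reusing the template layout produced by the script near the end of this diff:

import json
from pathlib import Path
from typing import Dict, Optional

from invokeai.backend.model_management.models.base import read_checkpoint_meta

def match_template(model: Path, template_path: Path) -> Optional[Dict]:
    """Return the template's metadata if every templated shape matches."""
    ckpt = read_checkpoint_meta(model)
    shapes = {key: list(tensor.shape) for key, tensor in ckpt.items()}
    with open(template_path, "r", encoding="utf-8") as f:
        template = json.load(f)
    for key, shape in template["template"].items():
        if shapes.get(key) != shape:
            return None  # one mismatch disqualifies this template
    return {k: template[k] for k in ("base_type", "model_type", "variant")}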

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -9,11 +9,35 @@ import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_manager import (
read_checkpoint_meta,
ModelType,
ModelVariantType,
BaseModelType,
)
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser = argparse.ArgumentParser(
description="Create a .json template from checkpoint/safetensors model",
)
parser.add_argument('checkpoint', type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
parser.add_argument("--base-type",
type=str,
choices=[x.value for x in BaseModelType],
help="Base model",
)
parser.add_argument("--model-type",
type=str,
choices=[x.value for x in ModelType],
default='main',
help="Type of the model",
)
parser.add_argument("--variant",
type=str,
choices=[x.value for x in ModelVariantType],
default='normal',
help="Base type of the model",
)
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
@@ -25,9 +49,16 @@ tmpl = {}
for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
meta = {
'base_type': opt.base_type,
'model_type': opt.model_type,
'variant': opt.variant,
'template': tmpl
}
try:
with open(opt.template, "w") as f:
json.dump(tmpl, f)
with open(opt.template, "w", encoding="utf-8") as f:
json.dump(meta, f)
print(f"Template written out as {opt.template}")
except Exception as e:
except OSError as e:
print(f"An exception occurred while writing template: {str(e)}")

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Exits 0 and prints "Match" if the checkpoint's metadata match the template; otherwise reports the first mismatch and exits -1.
"""
@@ -26,12 +27,14 @@ checkpoint_metadata = {}
for key, tensor in ckpt.items():
checkpoint_metadata[key] = list(tensor.shape)
with open(opt.template, "r") as f:
with open(opt.template, "r", encoding="utf-8") as f:
template = json.load(f)
if checkpoint_metadata == template:
print("True")
sys.exit(0)
else:
print("False")
sys.exit(-1)
for key in template["template"]:
val1 = checkpoint_metadata.get(key)
val2 = template["template"][key]
if val1 != val2:
print(f"mismatch: {key}: template={val2} != checkpoint={val1}")
sys.exit(-1)
print("Match")
sys.exit(0)
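To close the loop, the template-creation script above can be driven like this (the script name and the --base-type value are illustrative; the flags are the ones it defines):

python make_template.py /path/to/model.safetensors \
    --out sd-1-main.json --base-type sd-1 --model-type main --variant normal

The comparison script then loads that template, checks each templated key's shape against the candidate checkpoint, prints "Match" or the first mismatching key, and exits 0 or -1 accordingly.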