mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added main templates
This commit is contained in:
@ -1,14 +1,16 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.config
|
||||
"""
|
||||
from invokeai.backend.model_manager.config import ( # noqa F401
|
||||
ModelConfigFactory,
|
||||
ModelConfigBase,
|
||||
InvalidModelConfigException,
|
||||
from ..model_management.models.base import read_checkpoint_meta # noqa F401
|
||||
from .config import ( # noqa F401
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelVariantType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from .model_install import ModelInstall # noqa F401
|
||||
|
@ -50,7 +50,6 @@ class ModelType(str, Enum):
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
|
106
invokeai/backend/model_manager/model_install.py
Normal file
106
invokeai/backend/model_manager/model_install.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Install/delete models.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import ModelInstall
|
||||
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelConfigStoreSQL(config.db_path)
|
||||
installer = ModelInstall(store=store, config=config)
|
||||
|
||||
# register config, don't move path
|
||||
id: str = installer.register_model('/path/to/model')
|
||||
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_model('/path/to/model')
|
||||
|
||||
# unregister, don't delete
|
||||
installer.forget(id)
|
||||
|
||||
# unregister and delete model from disk
|
||||
installer.delete_model(id)
|
||||
|
||||
# scan directory recursively and install all new models found
|
||||
ids: List[str] = installer.scan_directory('/path/to/directory')
|
||||
|
||||
# unregister any model whose path is no longer valid
|
||||
ids: List[str] = installer.garbage_collect()
|
||||
|
||||
hash: str = installer.hash('/path/to/model') # should be same as id above
|
||||
|
||||
The following exceptions may be raised:
|
||||
DuplicateModelException
|
||||
UnknownModelTypeException
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from .storage import ModelConfigStore
|
||||
|
||||
|
||||
class ModelInstall(object):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
_config: InvokeAIAppConfig
|
||||
_logger: InvokeAILogger
|
||||
_store: ModelConfigStore
|
||||
|
||||
def __init__(self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None
|
||||
):
|
||||
"""
|
||||
Create ModelInstall object.
|
||||
|
||||
:param store: Optional ModelConfigStore. If None passed,
|
||||
defaults to `configs/models.yaml`.
|
||||
:param config: Optional InvokeAIAppConfig. If None passed,
|
||||
uses the system-wide default app config.
|
||||
:param logger: Optional InvokeAILogger. If None passed,
|
||||
uses the system-wide default logger.
|
||||
"""
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
self._logger = logger or InvokeAILogger.getLogger()
|
||||
if store is None:
|
||||
from .storage import ModelConfigStoreYAML
|
||||
store = ModelConfigStoreYAML(config.model_conf_path)
|
||||
self._store = store
|
||||
|
||||
|
||||
def register(self, model_path: Path) -> str:
|
||||
"""Probe and register the model at model_path."""
|
||||
pass
|
||||
|
||||
def install(self, model_path: Path) -> str:
|
||||
"""Probe, register and Install the model in the models directory."""
|
||||
pass
|
||||
|
||||
def forget(self, id: str) -> str:
|
||||
"""Unregister the model identified by id."""
|
||||
pass
|
||||
|
||||
def delete(self, id: str) -> str:
|
||||
"""
|
||||
Unregister and delete the model identified by id.
|
||||
Note that this deletes the model unconditionally.
|
||||
"""
|
||||
pass
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool=False) -> List[str]:
|
||||
"""Scan directory for new models and register or install them."""
|
||||
pass
|
||||
|
||||
def garbage_collect(self):
|
||||
"""Unregister any models whose paths are no longer valid."""
|
||||
pass
|
||||
|
||||
def hash(self, model_path: Path) -> str:
|
||||
"""Compute the fast hash of the model."""
|
||||
pass
|
||||
|
55
invokeai/backend/model_manager/probe.py
Normal file
55
invokeai/backend/model_manager/probe.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
|
||||
"""
|
||||
Module for probing a Stable Diffusion model and returning
|
||||
its base type, model type, format and variant.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
read_checkpoint_meta
|
||||
)
|
||||
import invokeai.configs.model_probe_templates as templates
|
||||
|
||||
from .config import (
|
||||
ModelType,
|
||||
BaseModelType,
|
||||
ModelVariantType,
|
||||
ModelFormat,
|
||||
SchedulerPredictionType
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
format: ModelFormat
|
||||
|
||||
class ModelProbe(object):
|
||||
"""
|
||||
Class to probe a checkpoint, safetensors or diffusers folder.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model: Path,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> ModelProbeInfo:
|
||||
"""
|
||||
Probe model located at path and return ModelProbeInfo object.
|
||||
A Callable may be passed to return the SchedulerPredictionType.
|
||||
"""
|
||||
pass
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -9,11 +9,35 @@ import json
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
||||
from invokeai.backend.model_manager import(
|
||||
read_checkpoint_meta,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
BaseModelType,
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
|
||||
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create a .json template from checkpoint/safetensors model",
|
||||
)
|
||||
parser.add_argument('checkpoint', type=Path, help="Path to the input checkpoint/safetensors file")
|
||||
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
|
||||
parser.add_argument("--base-type",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
help="Base model",
|
||||
)
|
||||
parser.add_argument("--model-type",
|
||||
type=str,
|
||||
choices=[x.value for x in ModelType],
|
||||
default='main',
|
||||
help="Type of the model",
|
||||
)
|
||||
parser.add_argument("--variant",
|
||||
type=str,
|
||||
choices=[x.value for x in ModelVariantType],
|
||||
default='normal',
|
||||
help="Base type of the model",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
ckpt = read_checkpoint_meta(opt.checkpoint)
|
||||
@ -25,9 +49,16 @@ tmpl = {}
|
||||
for key, tensor in ckpt.items():
|
||||
tmpl[key] = list(tensor.shape)
|
||||
|
||||
meta = {
|
||||
'base_type': opt.base_type,
|
||||
'model_type': opt.model_type,
|
||||
'variant': opt.variant,
|
||||
'template': tmpl
|
||||
}
|
||||
|
||||
try:
|
||||
with open(opt.template, "w") as f:
|
||||
json.dump(tmpl, f)
|
||||
with open(opt.template, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f)
|
||||
print(f"Template written out as {opt.template}")
|
||||
except Exception as e:
|
||||
except OSError as e:
|
||||
print(f"An exception occurred while writing template: {str(e)}")
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Read a checkpoint/safetensors file and compare it to a template .json.
|
||||
|
||||
Returns True if their metadata match.
|
||||
"""
|
||||
|
||||
@ -26,12 +27,14 @@ checkpoint_metadata = {}
|
||||
for key, tensor in ckpt.items():
|
||||
checkpoint_metadata[key] = list(tensor.shape)
|
||||
|
||||
with open(opt.template, "r") as f:
|
||||
with open(opt.template, "r", encoding="utf-8") as f:
|
||||
template = json.load(f)
|
||||
|
||||
if checkpoint_metadata == template:
|
||||
print("True")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("False")
|
||||
sys.exit(-1)
|
||||
for key in template["template"]:
|
||||
val1 = checkpoint_metadata.get(key)
|
||||
val2 = template["template"][key]
|
||||
if val1 != val2:
|
||||
print(f"mismatch: {key}: template={val2} != checkpoint={val1}")
|
||||
sys.exit(-1)
|
||||
print("Match")
|
||||
sys.exit(0)
|
||||
|
Reference in New Issue
Block a user