Lincoln Stein
2023-08-23 19:53:21 -04:00
parent 055ad0101d
commit 93cef55964
13 changed files with 59 additions and 60 deletions

View File

@@ -182,6 +182,7 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
class InvokeAIAppConfig(InvokeAISettings):
"""
Generate images using Stable Diffusion. Use "invokeai" to launch
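
As orientation for readers of this diff: a minimal sketch of how this settings class is obtained elsewhere in the codebase. The import path is an assumption; get_config() and models_path both appear in later hunks of this commit.

# Import path is an assumption based on the class name above.
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()  # singleton accessor used by ModelInstall below
print(config.models_path)  # models directory the installer copies files into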

View File

@@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.model_manager.config
"""
-from ..model_management.models.base import read_checkpoint_meta # noqa F401
+from ..model_management.models.base import read_checkpoint_meta  # noqa F401
from .config import ( # noqa F401
BaseModelType,
InvalidModelConfigException,

View File

@@ -30,6 +30,7 @@ from pydantic.error_wrappers import ValidationError
class InvalidModelConfigException(Exception):
"""Exception raised when the config parser doesn't recognize this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
@@ -50,6 +51,7 @@ class ModelType(str, Enum):
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
"""Submodel type."""
@@ -172,12 +174,10 @@ class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""
class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""
class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1."""

View File

@@ -55,7 +55,7 @@ class FastModelHash(object):
for file in files:
# Ignore the config files, which change locally,
# and just look at the bin files.
-if file in ['config.json', 'model_index.json']:
+if file in ["config.json", "model_index.json"]:
continue
path = Path(root) / file
fast_hash = cls._hash_file(path)
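
For context, a hedged sketch of how this hasher is invoked; the module path is an assumption, while FastModelHash.hash() itself is attested by the hash() method of ModelInstall later in this commit.

from pathlib import Path

# Module path is an assumption; only the FastModelHash class name appears in this diff.
from invokeai.backend.model_manager.hash import FastModelHash

# hash() accepts a single checkpoint file or a diffusers directory; for
# directories, per-file hashes are combined while config.json and
# model_index.json are skipped, as the hunk above shows.
digest = FastModelHash.hash(Path("models/sd-1/main/my-model.safetensors"))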

View File

@@ -59,11 +59,12 @@ class ModelInstallBase(ABC):
"""Abstract base class for InvokeAI model installation"""
@abstractmethod
-def __init__(self,
-store: Optional[ModelConfigStore] = None,
-config: Optional[InvokeAIAppConfig] = None,
-logger: Optional[InvokeAILogger] = None
-):
+def __init__(
+self,
+store: Optional[ModelConfigStore] = None,
+config: Optional[InvokeAIAppConfig] = None,
+logger: Optional[InvokeAILogger] = None,
+):
"""
Create ModelInstall object.
@@ -190,11 +191,12 @@ class ModelInstall(ModelInstallBase):
},
}
-def __init__(self,
-store: Optional[ModelConfigStore] = None,
-config: Optional[InvokeAIAppConfig] = None,
-logger: Optional[InvokeAILogger] = None
-): # noqa D107 - use base class docstrings
+def __init__(
+self,
+store: Optional[ModelConfigStore] = None,
+config: Optional[InvokeAIAppConfig] = None,
+logger: Optional[InvokeAILogger] = None,
+):  # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger()
if store is None:
@@ -213,7 +215,7 @@ class ModelInstall(ModelInstallBase):
name=model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
-model_format=info.format
+model_format=info.format,
)
# add 'main' specific fields
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
@@ -227,7 +229,7 @@ class ModelInstall(ModelInstallBase):
self._store.add_model(id, registration_data)
return id
-def install(self, model_path: Union[Path, str]) -> str: # noqa D102
+def install(self, model_path: Union[Path, str]) -> str:  # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
dest_path = self._config.models_path / info.base_model.value / info.model_type.value / model_path.name
@@ -243,22 +245,22 @@ class ModelInstall(ModelInstallBase):
info,
)
-def unregister(self, id: str): # noqa D102
+def unregister(self, id: str):  # noqa D102
self._store.del_model(id)
-def delete(self, id: str): # noqa D102
+def delete(self, id: str):  # noqa D102
model = self._store.get_model(id)
rmtree(model.path)
self.unregister(id)
-def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
+def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
search = ModelSearch()
search.model_found = self._scan_install if install else self._scan_register
self._installed = set()
search.search([scan_dir])
return list(self._installed)
-def garbage_collect(self) -> List[str]: # noqa D102
+def garbage_collect(self) -> List[str]:  # noqa D102
unregistered = list()
for model in self._store.all_models():
path = Path(model.path)
@@ -267,7 +269,7 @@ class ModelInstall(ModelInstallBase):
unregistered.append(model.id)
return unregistered
-def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
+def hash(self, model_path: Union[Path, str]) -> str:  # noqa D102
return FastModelHash.hash(model_path)
# the following two methods are callbacks to the ModelSearch object
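
Taken together, the reformatted methods above are the installer's public surface. A hedged usage sketch, assuming the module path; every class and method name below appears in the hunks above.

from pathlib import Path

from invokeai.backend.model_manager.install import ModelInstall  # path is an assumption

installer = ModelInstall()  # store, config, and logger fall back to defaults
model_id = installer.install(Path("downloads/my-model.safetensors"))  # probe, copy, register

found = installer.scan_directory(Path("external-models"), install=False)  # register without moving
orphans = installer.garbage_collect()  # unregister records whose files are gone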

View File

@@ -30,6 +30,7 @@ from .util import SilenceWarnings, lora_token_vector_length
class InvalidModelException(Exception):
"""Raised when an invalid model is encountered."""
@dataclass
class ModelProbeInfo(object):
"""Fields describing a probed model."""
@@ -49,9 +50,9 @@ class ModelProbeBase(ABC):
@classmethod
@abstractmethod
def probe(
-        cls,
-        model: Path,
-        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
+    cls,
+    model: Path,
+    prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Optional[ModelProbeInfo]:
"""
Probe model located at path and return ModelProbeInfo object.
@@ -62,6 +63,7 @@
"""
pass
class ProbeBase(ABC):
"""Base model for probing checkpoint and diffusers-style models."""
@@ -102,9 +104,7 @@ class ModelProbe(ModelProbeBase):
}
@classmethod
-def register_probe(
-cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase
-):
+def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase):
"""
Register a probe subclass to use when interrogating a model.
@@ -123,13 +123,9 @@ class ModelProbe(ModelProbeBase):
"""Probe model."""
try:
model_type = (
-cls.get_model_type_from_folder(model)
-if model.is_dir()
-else cls.get_model_type_from_checkpoint(model)
+cls.get_model_type_from_folder(model) if model.is_dir() else cls.get_model_type_from_checkpoint(model)
)
-format_type = "onnx" if model_type == ModelType.ONNX \
-else "diffusers" if model.is_dir() \
-else "checkpoint"
+format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
@@ -252,10 +248,12 @@ class ModelProbe(ModelProbeBase):
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
# ##################################################3
# Checkpoint probing
# ##################################################3
class CheckpointProbeBase(ProbeBase):
"""Base class for probing checkpoint-style models."""

View File

@@ -132,7 +132,7 @@ class ModelSearch(ModelSearchBase):
if self.search_completed:
self.search_completed(self._model_set)
-def list_models(self, directories: List[Union[Path,str]]) -> List[Path]:
+def list_models(self, directories: List[Union[Path, str]]) -> List[Path]:
"""Return list of models found"""
self.search(directories)
return list(self._model_set)
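
For context, list_models() drives the same callback machinery that ModelInstall.scan_directory() uses above. A hedged sketch, with the module path assumed; ModelSearch, the model_found callback, and list_models() all appear in this commit.

from pathlib import Path

from invokeai.backend.model_manager.search import ModelSearch  # path is an assumption

search = ModelSearch()
search.model_found = lambda model_path: print(f"found {model_path}")
models = search.list_models([Path("~/invokeai/models").expanduser()])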

View File

@@ -1,6 +1,6 @@
"""
Initialization file for invokeai.backend.model_manager.storage
"""
-from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
+from .base import ModelConfigStore, UnknownModelException, DuplicateModelException  # noqa F401
from .yaml import ModelConfigStoreYAML # noqa F401
from .sql import ModelConfigStoreSQL # noqa F401

View File

@@ -21,6 +21,7 @@ class InvalidModelException(Exception):
class UnknownModelException(Exception):
"""Raised on an attempt to delete a model with a nonexistent key."""
class ModelConfigStore(ABC):
"""Abstract base class for storage and retrieval of model configs."""

View File

@@ -307,11 +307,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
config=?
WHERE id=?;
""",
-(record.base_model,
-record.model_type,
-record.name,
-record.path,
-json_serialized, key),
+(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount < 1:
raise UnknownModelException

View File

@@ -12,6 +12,7 @@ from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
from picklescan.scanner import scan_file_path
class SilenceWarnings(object):
"""
Context manager that silences warnings from transformers and diffusers.
@@ -111,6 +112,7 @@ def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
return lora_token_vector_length
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
@@ -142,6 +144,7 @@ def _fast_safetensors_reader(path: str):
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
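
The _fast_safetensors_reader() helper above works because a .safetensors file begins with an 8-byte little-endian header length followed by a JSON header, so shapes and dtypes can be read without materializing tensor data. A standalone sketch of that layout (not InvokeAI's code):

import json
import struct

def read_safetensors_header(path: str) -> dict:
    # First 8 bytes: little-endian uint64 giving the size of the JSON header.
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        return json.loads(f.read(header_len))  # tensor shapes/dtypes, plus "__metadata__" if present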

View File

@@ -43,14 +43,12 @@ class CreateTemplateScanner(ModelSearch):
def write_template(self, model: Path, info: ModelProbeInfo):
"""Write template for a checkpoint file."""
-dest_path = Path(self._dest,
-"checkpoints" if model.is_file() else 'diffusers',
-info.base_type.value,
-info.model_type.value
-)
-template: dict = self._make_checkpoint_template(model) \
-if model.is_file() \
-else self._make_diffusers_template(model)
+dest_path = Path(
+self._dest, "checkpoints" if model.is_file() else "diffusers", info.base_type.value, info.model_type.value
+)
+template: dict = (
+self._make_checkpoint_template(model) if model.is_file() else self._make_diffusers_template(model)
+)
if not template:
print(f"Could not create template for {model}, got {template}")
return
@@ -105,7 +103,7 @@ class CreateTemplateScanner(ModelSearch):
tmpl = None
if (model / "model_index.json").exists(): # a pipeline
tmpl = {}
-for subdir in ['unet', 'text_encoder', 'vae', 'text_encoder_2']:
+for subdir in ["unet", "text_encoder", "vae", "text_encoder_2"]:
config = model / subdir / "config.json"
try:
tmpl[subdir] = self._read_config(config)
@@ -122,7 +120,7 @@ class CreateTemplateScanner(ModelSearch):
return tmpl
def _read_config(self, config: Path) -> dict:
-with open(config, 'r', encoding='utf-8') as f:
+with open(config, "r", encoding="utf-8") as f:
return {x: y for x, y in json.load(f).items() if not x.startswith("_")}
def on_search_completed(self):
@@ -137,16 +135,14 @@
parser = argparse.ArgumentParser(
description="Scan the provided path recursively and create .json templates for all models found.",
)
parser.add_argument("--scan",
type=Path,
help="Path to recursively scan for models"
)
parser.add_argument("--out",
type=Path,
dest="outdir",
default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
help="Destination for templates",
)
parser.add_argument("--scan", type=Path, help="Path to recursively scan for models")
parser.add_argument(
"--out",
type=Path,
dest="outdir",
default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
help="Destination for templates",
)
opt = parser.parse_args()
scanner = CreateTemplateScanner([opt.scan], dest=opt.outdir)

View File

@@ -11,10 +11,12 @@ from invokeai.backend.model_manager import (
InvalidModelException,
)
def helper(model_path: Path):
print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
return SchedulerPredictionType.VPrediction
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
"model_path",