mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackify
This commit is contained in:
@ -182,6 +182,7 @@ DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.config
|
||||
"""
|
||||
from ..model_management.models.base import read_checkpoint_meta # noqa F401
|
||||
from ..model_management.models.base import read_checkpoint_meta # noqa F401
|
||||
from .config import ( # noqa F401
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
|
@ -30,6 +30,7 @@ from pydantic.error_wrappers import ValidationError
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
@ -50,6 +51,7 @@ class ModelType(str, Enum):
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
@ -172,12 +174,10 @@ class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
|
||||
|
||||
class ONNXSD1Config(MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
|
@ -55,7 +55,7 @@ class FastModelHash(object):
|
||||
for file in files:
|
||||
# Ignore the config files, which change locally,
|
||||
# and just look at the bin files.
|
||||
if file in ['config.json', 'model_index.json']:
|
||||
if file in ["config.json", "model_index.json"]:
|
||||
continue
|
||||
path = Path(root) / file
|
||||
fast_hash = cls._hash_file(path)
|
||||
|
@ -59,11 +59,12 @@ class ModelInstallBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None,
|
||||
):
|
||||
"""
|
||||
Create ModelInstall object.
|
||||
|
||||
@ -190,11 +191,12 @@ class ModelInstall(ModelInstallBase):
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None
|
||||
): # noqa D107 - use base class docstrings
|
||||
def __init__(
|
||||
self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None,
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
self._logger = logger or InvokeAILogger.getLogger()
|
||||
if store is None:
|
||||
@ -213,7 +215,7 @@ class ModelInstall(ModelInstallBase):
|
||||
name=model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format
|
||||
model_format=info.format,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
|
||||
@ -227,7 +229,7 @@ class ModelInstall(ModelInstallBase):
|
||||
self._store.add_model(id, registration_data)
|
||||
return id
|
||||
|
||||
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
dest_path = self._config.models_path / info.base_model.value / info.model_type.value / model_path.name
|
||||
@ -243,22 +245,22 @@ class ModelInstall(ModelInstallBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def unregister(self, id: str): # noqa D102
|
||||
def unregister(self, id: str): # noqa D102
|
||||
self._store.del_model(id)
|
||||
|
||||
def delete(self, id: str): # noqa D102
|
||||
def delete(self, id: str): # noqa D102
|
||||
model = self._store.get_model(id)
|
||||
rmtree(model.path)
|
||||
self.unregister(id)
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
search = ModelSearch()
|
||||
search.model_found = self._scan_install if install else self._scan_register
|
||||
self._installed = set()
|
||||
search.search([scan_dir])
|
||||
return list(self._installed)
|
||||
|
||||
def garbage_collect(self) -> List[str]: # noqa D102
|
||||
def garbage_collect(self) -> List[str]: # noqa D102
|
||||
unregistered = list()
|
||||
for model in self._store.all_models():
|
||||
path = Path(model.path)
|
||||
@ -267,7 +269,7 @@ class ModelInstall(ModelInstallBase):
|
||||
unregistered.append(model.id)
|
||||
return unregistered
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
# the following two methods are callbacks to the ModelSearch object
|
||||
|
@ -30,6 +30,7 @@ from .util import SilenceWarnings, lora_token_vector_length
|
||||
class InvalidModelException(Exception):
|
||||
"""Raised when an invalid model is encountered."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
"""Fields describing a probed model."""
|
||||
@ -49,9 +50,9 @@ class ModelProbeBase(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def probe(
|
||||
cls,
|
||||
model: Path,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
cls,
|
||||
model: Path,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> Optional[ModelProbeInfo]:
|
||||
"""
|
||||
Probe model located at path and return ModelProbeInfo object.
|
||||
@ -62,6 +63,7 @@ class ModelProbeBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProbeBase(ABC):
|
||||
"""Base model for probing checkpoint and diffusers-style models."""
|
||||
|
||||
@ -102,9 +104,7 @@ class ModelProbe(ModelProbeBase):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase):
|
||||
"""
|
||||
Register a probe subclass to use when interrogating a model.
|
||||
|
||||
@ -123,13 +123,9 @@ class ModelProbe(ModelProbeBase):
|
||||
"""Probe model."""
|
||||
try:
|
||||
model_type = (
|
||||
cls.get_model_type_from_folder(model)
|
||||
if model.is_dir()
|
||||
else cls.get_model_type_from_checkpoint(model)
|
||||
cls.get_model_type_from_folder(model) if model.is_dir() else cls.get_model_type_from_checkpoint(model)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX \
|
||||
else "diffusers" if model.is_dir() \
|
||||
else "checkpoint"
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
@ -252,10 +248,12 @@ class ModelProbe(ModelProbeBase):
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
"""Base class for probing checkpoint-style models."""
|
||||
|
||||
|
@ -132,7 +132,7 @@ class ModelSearch(ModelSearchBase):
|
||||
if self.search_completed:
|
||||
self.search_completed(self._model_set)
|
||||
|
||||
def list_models(self, directories: List[Union[Path,str]]) -> List[Path]:
|
||||
def list_models(self, directories: List[Union[Path, str]]) -> List[Path]:
|
||||
"""Return list of models found"""
|
||||
self.search(directories)
|
||||
return list(self._model_set)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.storage
|
||||
"""
|
||||
from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
|
||||
from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
|
||||
from .yaml import ModelConfigStoreYAML # noqa F401
|
||||
from .sql import ModelConfigStoreSQL # noqa F401
|
||||
|
@ -21,6 +21,7 @@ class InvalidModelException(Exception):
|
||||
class UnknownModelException(Exception):
|
||||
"""Raised on an attempt to delete a model with a nonexistent key."""
|
||||
|
||||
|
||||
class ModelConfigStore(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
|
@ -307,11 +307,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base_model,
|
||||
record.model_type,
|
||||
record.name,
|
||||
record.path,
|
||||
json_serialized, key),
|
||||
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount < 1:
|
||||
raise UnknownModelException
|
||||
|
@ -12,6 +12,7 @@ from diffusers import logging as diffusers_logging
|
||||
from transformers import logging as transformers_logging
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
"""
|
||||
Context manager that silences warnings from transformers and diffusers.
|
||||
@ -111,6 +112,7 @@ def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
@ -142,6 +144,7 @@ def _fast_safetensors_reader(path: str):
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
|
@ -43,14 +43,12 @@ class CreateTemplateScanner(ModelSearch):
|
||||
|
||||
def write_template(self, model: Path, info: ModelProbeInfo):
|
||||
"""Write template for a checkpoint file."""
|
||||
dest_path = Path(self._dest,
|
||||
"checkpoints" if model.is_file() else 'diffusers',
|
||||
info.base_type.value,
|
||||
info.model_type.value
|
||||
)
|
||||
template: dict = self._make_checkpoint_template(model) \
|
||||
if model.is_file() \
|
||||
else self._make_diffusers_template(model)
|
||||
dest_path = Path(
|
||||
self._dest, "checkpoints" if model.is_file() else "diffusers", info.base_type.value, info.model_type.value
|
||||
)
|
||||
template: dict = (
|
||||
self._make_checkpoint_template(model) if model.is_file() else self._make_diffusers_template(model)
|
||||
)
|
||||
if not template:
|
||||
print(f"Could not create template for {model}, got {template}")
|
||||
return
|
||||
@ -105,7 +103,7 @@ class CreateTemplateScanner(ModelSearch):
|
||||
tmpl = None
|
||||
if (model / "model_index.json").exists(): # a pipeline
|
||||
tmpl = {}
|
||||
for subdir in ['unet', 'text_encoder', 'vae', 'text_encoder_2']:
|
||||
for subdir in ["unet", "text_encoder", "vae", "text_encoder_2"]:
|
||||
config = model / subdir / "config.json"
|
||||
try:
|
||||
tmpl[subdir] = self._read_config(config)
|
||||
@ -122,7 +120,7 @@ class CreateTemplateScanner(ModelSearch):
|
||||
return tmpl
|
||||
|
||||
def _read_config(self, config: Path) -> dict:
|
||||
with open(config, 'r', encoding='utf-8') as f:
|
||||
with open(config, "r", encoding="utf-8") as f:
|
||||
return {x: y for x, y in json.load(f).items() if not x.startswith("_")}
|
||||
|
||||
def on_search_completed(self):
|
||||
@ -137,16 +135,14 @@ class CreateTemplateScanner(ModelSearch):
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Scan the provided path recursively and create .json templates for all models found.",
|
||||
)
|
||||
parser.add_argument("--scan",
|
||||
type=Path,
|
||||
help="Path to recursively scan for models"
|
||||
)
|
||||
parser.add_argument("--out",
|
||||
type=Path,
|
||||
dest="outdir",
|
||||
default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
|
||||
help="Destination for templates",
|
||||
)
|
||||
parser.add_argument("--scan", type=Path, help="Path to recursively scan for models")
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
type=Path,
|
||||
dest="outdir",
|
||||
default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
|
||||
help="Destination for templates",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
scanner = CreateTemplateScanner([opt.scan], dest=opt.outdir)
|
||||
|
@ -11,10 +11,12 @@ from invokeai.backend.model_manager import (
|
||||
InvalidModelException,
|
||||
)
|
||||
|
||||
|
||||
def helper(model_path: Path):
|
||||
print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
|
||||
return SchedulerPredictionType.VPrediction
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Probe model type")
|
||||
parser.add_argument(
|
||||
"model_path",
|
||||
|
Reference in New Issue
Block a user