Merge branch 'main' into refactor/model-manager-3

This commit is contained in:
Lincoln Stein 2023-11-27 22:15:51 -05:00 committed by GitHub
commit ecd3dcd5df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 142 additions and 163 deletions

View File

@ -1,14 +1,20 @@
import sys
from typing import Any from typing import Any
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config # parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the # which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file. # values from the command line or config file.
from invokeai.version.invokeai_version import __version__
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args() app_config.parse_args()
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
if True: # hack to make flake8 happy with imports coming after setting up the config if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio import asyncio
@ -34,7 +40,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
@ -51,7 +56,12 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
workflows, workflows,
) )
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
_InputField,
_OutputField,
)
if is_mps_available(): if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
@ -273,7 +283,4 @@ def invoke_api() -> None:
if __name__ == "__main__": if __name__ == "__main__":
if app_config.version: invoke_api()
print(f"InvokeAI version {__version__}")
else:
invoke_api()

View File

@ -5,7 +5,7 @@ from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute()) custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
custom_nodes_path.mkdir(parents=True, exist_ok=True) custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py") custom_nodes_init_path = str(custom_nodes_path / "__init__.py")

View File

@ -15,7 +15,7 @@ import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -24,10 +24,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
""" """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
initconf: ClassVar[Optional[DictConfig]] = None initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {} argparse_groups: ClassVar[Dict] = {}
@ -35,6 +32,7 @@ class InvokeAISettings(BaseSettings):
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[list] = sys.argv[1:]): def parse_args(self, argv: Optional[list] = sys.argv[1:]):
"""Call to parse command-line arguments."""
parser = self.get_parser() parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv) opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0: if len(unknown_opts) > 0:
@ -49,20 +47,19 @@ class InvokeAISettings(BaseSettings):
setattr(self, name, value) setattr(self, name, value)
def to_yaml(self) -> str: def to_yaml(self) -> str:
""" """Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__ cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0] type = get_args(get_type_hints(cls)["type"])[0]
field_dict = {type: {}} field_dict: Dict[str, Dict[str, Any]] = {type: {}}
for name, field in self.model_fields.items(): for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml(): if name in cls._excluded_from_yaml():
continue continue
assert isinstance(field.json_schema_extra, dict)
category = ( category = (
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized" field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
) )
value = getattr(self, name) value = getattr(self, name)
assert isinstance(category, str)
if category not in field_dict[type]: if category not in field_dict[type]:
field_dict[type][category] = {} field_dict[type][category] = {}
# keep paths as strings to make it easier to read # keep paths as strings to make it easier to read
@ -72,6 +69,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def add_parser_arguments(cls, parser): def add_parser_arguments(cls, parser):
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls): if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0] settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else: else:
@ -116,6 +114,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def cmd_name(cls, command_field: str = "type") -> str: def cmd_name(cls, command_field: str = "type") -> str:
"""Return the category of a setting."""
hints = get_type_hints(cls) hints = get_type_hints(cls)
if command_field in hints: if command_field in hints:
return get_args(hints[command_field])[0] return get_args(hints[command_field])[0]
@ -124,6 +123,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def get_parser(cls) -> ArgumentParser: def get_parser(cls) -> ArgumentParser:
"""Get the command-line parser for a setting."""
parser = PagingArgumentParser( parser = PagingArgumentParser(
prog=cls.cmd_name(), prog=cls.cmd_name(),
description=cls.__doc__, description=cls.__doc__,
@ -152,10 +152,14 @@ class InvokeAISettings(BaseSettings):
"free_gpu_mem", "free_gpu_mem",
"xformers_enabled", "xformers_enabled",
"tiled_decode", "tiled_decode",
"lora_dir",
"embedding_dir",
"controlnet_dir",
] ]
@classmethod @classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None): def add_field_argument(cls, command_parser, name: str, field, default_override=None):
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name) field_type = get_type_hints(cls).get(name)
default = ( default = (
default_override default_override

View File

@ -177,6 +177,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict from pydantic_settings import SettingsConfigDict
from .config_base import InvokeAISettings from .config_base import InvokeAISettings
@ -188,28 +189,24 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object): class Categories(object):
WebServer = {"category": "Web Server"} """Category headers for configuration variable groups."""
Features = {"category": "Features"}
Paths = {"category": "Paths"} WebServer: JsonDict = {"category": "Web Server"}
Logging = {"category": "Logging"} Features: JsonDict = {"category": "Features"}
Development = {"category": "Development"} Paths: JsonDict = {"category": "Paths"}
Other = {"category": "Other"} Logging: JsonDict = {"category": "Logging"}
ModelCache = {"category": "Model Cache"} Development: JsonDict = {"category": "Development"}
Device = {"category": "Device"} Other: JsonDict = {"category": "Other"}
Generation = {"category": "Generation"} ModelCache: JsonDict = {"category": "Model Cache"}
Queue = {"category": "Queue"} Device: JsonDict = {"category": "Device"}
Nodes = {"category": "Nodes"} Generation: JsonDict = {"category": "Generation"}
MemoryPerformance = {"category": "Memory/Performance"} Queue: JsonDict = {"category": "Queue"}
Nodes: JsonDict = {"category": "Nodes"}
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
class InvokeAIAppConfig(InvokeAISettings): class InvokeAIAppConfig(InvokeAISettings):
""" """Configuration object for InvokeAI App."""
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
"""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None singleton_init: ClassVar[Optional[Dict]] = None
@ -234,15 +231,12 @@ class InvokeAIAppConfig(InvokeAISettings):
# PATHS # PATHS
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths) root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths) use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths) custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths) from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
@ -285,11 +279,15 @@ class InvokeAIAppConfig(InvokeAISettings):
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance) always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance) max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance) max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance) xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance) tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories # See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on # fmt: on
@ -303,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
clobber=False, clobber=False,
): ):
""" """
Update settings with contents of init file, environment, and Update settings with contents of init file, environment, and command-line settings.
command-line settings.
:param conf: alternate Omegaconf dictionary object :param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list :param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization :param clobber: ovewrite any initialization parameters passed during initialization
@ -349,9 +347,7 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def root_path(self) -> Path: def root_path(self) -> Path:
""" """Path to the runtime root directory."""
Path to the runtime root directory
"""
if self.root: if self.root:
root = Path(self.root).expanduser().absolute() root = Path(self.root).expanduser().absolute()
else: else:
@ -361,9 +357,7 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def root_dir(self) -> Path: def root_dir(self) -> Path:
""" """Alias for above."""
Alias for above.
"""
return self.root_path return self.root_path
def _resolve(self, partial_path: Path) -> Path: def _resolve(self, partial_path: Path) -> Path:
@ -371,108 +365,95 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def init_file_path(self) -> Path: def init_file_path(self) -> Path:
""" """Path to invokeai.yaml."""
Path to invokeai.yaml resolved_path = self._resolve(INIT_FILE)
""" assert resolved_path is not None
return self._resolve(INIT_FILE) return resolved_path
@property @property
def output_path(self) -> Path: def output_path(self) -> Optional[Path]:
""" """Path to defaults outputs directory."""
Path to defaults outputs directory.
"""
return self._resolve(self.outdir) return self._resolve(self.outdir)
@property @property
def db_path(self) -> Path: def db_path(self) -> Path:
""" """Path to the invokeai.db file."""
Path to the invokeai.db file. db_dir = self._resolve(self.db_dir)
""" assert db_dir is not None
return self._resolve(self.db_dir) / DB_FILE return db_dir / DB_FILE
@property @property
def model_conf_path(self) -> Path: def model_conf_path(self) -> Optional[Path]:
""" """Path to models configuration file."""
Path to models configuration file.
"""
return self._resolve(self.conf_path) return self._resolve(self.conf_path)
@property @property
def legacy_conf_path(self) -> Path: def legacy_conf_path(self) -> Optional[Path]:
""" """Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
"""
return self._resolve(self.legacy_conf_dir) return self._resolve(self.legacy_conf_dir)
@property @property
def models_path(self) -> Path: def models_path(self) -> Optional[Path]:
""" """Path to the models directory."""
Path to the models directory
"""
return self._resolve(self.models_dir) return self._resolve(self.models_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
""" """Path to the custom nodes directory."""
Path to the custom nodes directory custom_nodes_path = self._resolve(self.custom_nodes_dir)
""" assert custom_nodes_path is not None
return self._resolve(self.custom_nodes_dir) return custom_nodes_path
# the following methods support legacy calls leftover from the Globals era # the following methods support legacy calls leftover from the Globals era
@property @property
def full_precision(self) -> bool: def full_precision(self) -> bool:
"""Return true if precision set to float32""" """Return true if precision set to float32."""
return self.precision == "float32" return self.precision == "float32"
@property @property
def try_patchmatch(self) -> bool: def try_patchmatch(self) -> bool:
"""Return true if patchmatch true""" """Return true if patchmatch true."""
return self.patchmatch return self.patchmatch
@property @property
def nsfw_checker(self) -> bool: def nsfw_checker(self) -> bool:
"""NSFW node is always active and disabled from Web UIe""" """Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
return True return True
@property @property
def invisible_watermark(self) -> bool: def invisible_watermark(self) -> bool:
"""invisible watermark node is always active and disabled from Web UIe""" """Return value of invisible watermark. It is always active and disabled from Web UI."""
return True return True
@property @property
def ram_cache_size(self) -> Union[Literal["auto"], float]: def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
return self.max_cache_size or self.ram return self.max_cache_size or self.ram
@property @property
def vram_cache_size(self) -> Union[Literal["auto"], float]: def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
return self.max_vram_cache_size or self.vram return self.max_vram_cache_size or self.vram
@property @property
def use_cpu(self) -> bool: def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
return self.always_use_cpu or self.device == "cpu" return self.always_use_cpu or self.device == "cpu"
@property @property
def disable_xformers(self) -> bool: def disable_xformers(self) -> bool:
""" """Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
Return true if enable_xformers is false (reversed logic)
and attention type is not set to xformers.
"""
disabled_in_config = not self.xformers_enabled disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers" return disabled_in_config and self.attention_type != "xformers"
@staticmethod @staticmethod
def find_root() -> Path: def find_root() -> Path:
""" """Choose the runtime root directory when not specified on command line or init file."""
Choose the runtime root directory when not specified on command line or
init file.
"""
return _find_root() return _find_root()
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
""" """Legacy function which returns InvokeAIAppConfig.get_config()."""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
return InvokeAIAppConfig.get_config(**kwargs) return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -48,7 +48,6 @@ from typing import List, Optional, Union
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ModelConfigBase,
ModelConfigFactory, ModelConfigFactory,
ModelType, ModelType,
) )
@ -158,7 +157,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
("version", CONFIG_FILE_VERSION), ("version", CONFIG_FILE_VERSION),
) )
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig: def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -255,7 +254,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig: def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
""" """
Update the model, returning the updated version. Update the model, returning the updated version.
@ -368,7 +367,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path.""" """Return models with the indicated path."""
results = [] results = []
with self._db.lock: with self._db.lock:
@ -382,7 +381,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]: def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash.""" """Return models with the indicated original_hash."""
results = [] results = []
with self._db.lock: with self._db.lock:

View File

@ -238,7 +238,7 @@ class ModelProbe(object):
with SilenceWarnings(): with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path) cls._scan_model(model_path, model_path)
return torch.load(model_path) return torch.load(model_path, map_location="cpu")
else: else:
return safetensors.torch.load_file(model_path) return safetensors.torch.load_file(model_path)

View File

@ -1,8 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.backend.util.logging """
Logging class for InvokeAI that produces console messages.
Logging class for InvokeAI that produces console messages
Usage: Usage:
@ -178,8 +177,8 @@ InvokeAI:
import logging.handlers import logging.handlers
import socket import socket
import urllib.parse import urllib.parse
from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -192,36 +191,36 @@ except ImportError:
# module level functions # module level functions
def debug(msg, *args, **kwargs): def debug(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().debug(msg, *args, **kwargs) InvokeAILogger.get_logger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs): def info(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().info(msg, *args, **kwargs) InvokeAILogger.get_logger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs): def warning(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().warning(msg, *args, **kwargs) InvokeAILogger.get_logger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs): def error(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().error(msg, *args, **kwargs) InvokeAILogger.get_logger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs): def critical(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().critical(msg, *args, **kwargs) InvokeAILogger.get_logger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs): def log(level: int, msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().log(level, msg, *args, **kwargs) InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL): def disable(level: int = logging.CRITICAL) -> None: # noqa D103
InvokeAILogger.get_logger().disable(level) logging.disable(level)
def basicConfig(**kwargs): def basicConfig(**kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().basicConfig(**kwargs) logging.basicConfig(**kwargs)
_FACILITY_MAP = ( _FACILITY_MAP = (
@ -256,33 +255,25 @@ _SOCK_MAP = {
class InvokeAIFormatter(logging.Formatter): class InvokeAIFormatter(logging.Formatter):
""" """Base class for logging formatter."""
Base class for logging formatter
""" def format(self, record: logging.LogRecord) -> str: # noqa D102
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno)) formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record) return formatter.format(record)
@abstractmethod def log_fmt(self, levelno: int) -> str: # noqa D102
def log_fmt(self, levelno: int) -> str: return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
pass
class InvokeAISyslogFormatter(InvokeAIFormatter): class InvokeAISyslogFormatter(InvokeAIFormatter):
""" """Formatting for syslog."""
Formatting for syslog
"""
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return "%(name)s [%(process)d] <%(levelname)s> %(message)s" return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
class InvokeAILegacyLogFormatter(InvokeAIFormatter): class InvokeAILegacyLogFormatter(InvokeAIFormatter): # noqa D102
""" """Formatting for the InvokeAI Logger (legacy version)."""
Formatting for the InvokeAI Logger (legacy version)
"""
FORMATS = { FORMATS = {
logging.DEBUG: " | %(message)s", logging.DEBUG: " | %(message)s",
@ -292,23 +283,21 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
logging.CRITICAL: "### %(message)s", logging.CRITICAL: "### %(message)s",
} }
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return self.FORMATS.get(levelno) format = self.FORMATS.get(levelno)
assert format is not None
return format
class InvokeAIPlainLogFormatter(InvokeAIFormatter): class InvokeAIPlainLogFormatter(InvokeAIFormatter):
""" """Custom Formatting for the InvokeAI Logger (plain version)."""
Custom Formatting for the InvokeAI Logger (plain version)
"""
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s" return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter): class InvokeAIColorLogFormatter(InvokeAIFormatter):
""" """Custom Formatting for the InvokeAI Logger."""
Custom Formatting for the InvokeAI Logger
"""
# Color Codes # Color Codes
grey = "\x1b[38;20m" grey = "\x1b[38;20m"
@ -331,8 +320,10 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
logging.CRITICAL: bold_red + log_format + reset, logging.CRITICAL: bold_red + log_format + reset,
} }
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return self.FORMATS.get(levelno) format = self.FORMATS.get(levelno)
assert format is not None
return format
LOG_FORMATTERS = { LOG_FORMATTERS = {
@ -343,13 +334,13 @@ LOG_FORMATTERS = {
} }
class InvokeAILogger(object): class InvokeAILogger(object): # noqa D102
loggers = {} loggers: Dict[str, logging.Logger] = {}
@classmethod @classmethod
def get_logger( def get_logger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger: ) -> logging.Logger: # noqa D102
if name in cls.loggers: if name in cls.loggers:
logger = cls.loggers[name] logger = cls.loggers[name]
logger.handlers.clear() logger.handlers.clear()
@ -362,7 +353,7 @@ class InvokeAILogger(object):
return cls.loggers[name] return cls.loggers[name]
@classmethod @classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: # noqa D102
handler_strs = config.log_handlers handler_strs = config.log_handlers
handlers = [] handlers = []
for handler in handler_strs: for handler in handler_strs:
@ -374,7 +365,7 @@ class InvokeAILogger(object):
# http gets no custom formatter # http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format] formatter = LOG_FORMATTERS[config.log_format]
if handler_name == "console": if handler_name == "console":
ch = logging.StreamHandler() ch: logging.Handler = logging.StreamHandler()
ch.setFormatter(formatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
@ -393,18 +384,18 @@ class InvokeAILogger(object):
return handlers return handlers
@staticmethod @staticmethod
def _parse_syslog_args(args: str = None) -> logging.Handler: def _parse_syslog_args(args: Optional[str] = None) -> logging.Handler:
if not SYSLOG_AVAILABLE: if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system") raise ValueError("syslog is not available on this system")
if not args: if not args:
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514" args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = {} syslog_args: Dict[str, Any] = {}
try: try:
for a in args.split(","): for a in args.split(","):
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
if arg_name == "address": if arg_name == "address":
host, *port = arg_value host, *port_list = arg_value
port = 514 if len(port) == 0 else int(port[0]) port = 514 if not port_list else int(port_list[0])
syslog_args["address"] = (host, port) syslog_args["address"] = (host, port)
elif arg_name == "facility": elif arg_name == "facility":
syslog_args["facility"] = _FACILITY_MAP[arg_value[0]] syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
@ -417,13 +408,13 @@ class InvokeAILogger(object):
return logging.handlers.SysLogHandler(**syslog_args) return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod @staticmethod
def _parse_file_args(args: str = None) -> logging.Handler: def _parse_file_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
if not args: if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'") raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args) return logging.FileHandler(args)
@staticmethod @staticmethod
def _parse_http_args(args: str = None) -> logging.Handler: def _parse_http_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
if not args: if not args:
raise ValueError("please provide destination for http logging using format 'http=url'") raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(",") arg_list = args.split(",")
@ -434,12 +425,12 @@ class InvokeAILogger(object):
path = url.path path = url.path
port = url.port or 80 port = url.port or 80
syslog_args = {} syslog_args: Dict[str, Any] = {}
for a in arg_list: for a in arg_list:
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
if arg_name == "method": if arg_name == "method":
arg_value = arg_value[0] if len(arg_value) > 0 else "GET" method = arg_value[0] if len(arg_value) > 0 else "GET"
syslog_args[arg_name] = arg_value syslog_args[arg_name] = method
else: # TODO: Provide support for SSL context and credentials else: # TODO: Provide support for SSL context and credentials
pass pass
return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args) return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)

View File

@ -229,8 +229,6 @@ module = [
"invokeai.app.api.routers.models", "invokeai.app.api.routers.models",
"invokeai.app.invocations.compel", "invokeai.app.invocations.compel",
"invokeai.app.invocations.latent", "invokeai.app.invocations.latent",
"invokeai.app.services.config.config_base",
"invokeai.app.services.config.config_default",
"invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.invocation_stats.invocation_stats_default",
"invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_base",
"invokeai.app.services.model_manager.model_manager_default", "invokeai.app.services.model_manager.model_manager_default",
@ -265,7 +263,6 @@ module = [
"invokeai.backend.stable_diffusion.diffusion.cross_attention_control", "invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
"invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion", "invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
"invokeai.backend.util.hotfixes", "invokeai.backend.util.hotfixes",
"invokeai.backend.util.logging",
"invokeai.backend.util.mps_fixes", "invokeai.backend.util.mps_fixes",
"invokeai.backend.util.util", "invokeai.backend.util.util",
"invokeai.frontend.install.model_install", "invokeai.frontend.install.model_install",