mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into refactor/model-manager-3
This commit is contained in:
commit
ecd3dcd5df
@ -1,14 +1,20 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import asyncio
|
||||
@ -34,7 +40,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
@ -51,7 +56,12 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
_InputField,
|
||||
_OutputField,
|
||||
)
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
@ -273,7 +283,4 @@ def invoke_api() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_api()
|
||||
invoke_api()
|
||||
|
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
|
||||
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute())
|
||||
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||
|
@ -15,7 +15,7 @@ import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
@ -24,10 +24,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
@ -35,6 +32,7 @@ class InvokeAISettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
"""Call to parse command-line arguments."""
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
if len(unknown_opts) > 0:
|
||||
@ -49,20 +47,19 @@ class InvokeAISettings(BaseSettings):
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
"""Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = {type: {}}
|
||||
field_dict: Dict[str, Dict[str, Any]] = {type: {}}
|
||||
for name, field in self.model_fields.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
assert isinstance(field.json_schema_extra, dict)
|
||||
category = (
|
||||
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||
)
|
||||
value = getattr(self, name)
|
||||
assert isinstance(category, str)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = {}
|
||||
# keep paths as strings to make it easier to read
|
||||
@ -72,6 +69,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
@ -116,6 +114,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def cmd_name(cls, command_field: str = "type") -> str:
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
@ -124,6 +123,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
"""Get the command-line parser for a setting."""
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
@ -152,10 +152,14 @@ class InvokeAISettings(BaseSettings):
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
|
@ -177,6 +177,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic.config import JsonDict
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from .config_base import InvokeAISettings
|
||||
@ -188,28 +189,24 @@ DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class Categories(object):
|
||||
WebServer = {"category": "Web Server"}
|
||||
Features = {"category": "Features"}
|
||||
Paths = {"category": "Paths"}
|
||||
Logging = {"category": "Logging"}
|
||||
Development = {"category": "Development"}
|
||||
Other = {"category": "Other"}
|
||||
ModelCache = {"category": "Model Cache"}
|
||||
Device = {"category": "Device"}
|
||||
Generation = {"category": "Generation"}
|
||||
Queue = {"category": "Queue"}
|
||||
Nodes = {"category": "Nodes"}
|
||||
MemoryPerformance = {"category": "Memory/Performance"}
|
||||
"""Category headers for configuration variable groups."""
|
||||
|
||||
WebServer: JsonDict = {"category": "Web Server"}
|
||||
Features: JsonDict = {"category": "Features"}
|
||||
Paths: JsonDict = {"category": "Paths"}
|
||||
Logging: JsonDict = {"category": "Logging"}
|
||||
Development: JsonDict = {"category": "Development"}
|
||||
Other: JsonDict = {"category": "Other"}
|
||||
ModelCache: JsonDict = {"category": "Model Cache"}
|
||||
Device: JsonDict = {"category": "Device"}
|
||||
Generation: JsonDict = {"category": "Generation"}
|
||||
Queue: JsonDict = {"category": "Queue"}
|
||||
Nodes: JsonDict = {"category": "Nodes"}
|
||||
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
"""
|
||||
"""Configuration object for InvokeAI App."""
|
||||
|
||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||
singleton_init: ClassVar[Optional[Dict]] = None
|
||||
@ -234,15 +231,12 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# PATHS
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||
@ -285,11 +279,15 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
# this is not referred to in the source code and can be removed entirely
|
||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
@ -303,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
clobber=False,
|
||||
):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
Update settings with contents of init file, environment, and command-line settings.
|
||||
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
@ -349,9 +347,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
@property
|
||||
def root_path(self) -> Path:
|
||||
"""
|
||||
Path to the runtime root directory
|
||||
"""
|
||||
"""Path to the runtime root directory."""
|
||||
if self.root:
|
||||
root = Path(self.root).expanduser().absolute()
|
||||
else:
|
||||
@ -361,9 +357,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
@property
|
||||
def root_dir(self) -> Path:
|
||||
"""
|
||||
Alias for above.
|
||||
"""
|
||||
"""Alias for above."""
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self, partial_path: Path) -> Path:
|
||||
@ -371,108 +365,95 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
@property
|
||||
def init_file_path(self) -> Path:
|
||||
"""
|
||||
Path to invokeai.yaml
|
||||
"""
|
||||
return self._resolve(INIT_FILE)
|
||||
"""Path to invokeai.yaml."""
|
||||
resolved_path = self._resolve(INIT_FILE)
|
||||
assert resolved_path is not None
|
||||
return resolved_path
|
||||
|
||||
@property
|
||||
def output_path(self) -> Path:
|
||||
"""
|
||||
Path to defaults outputs directory.
|
||||
"""
|
||||
def output_path(self) -> Optional[Path]:
|
||||
"""Path to defaults outputs directory."""
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
"""
|
||||
Path to the invokeai.db file.
|
||||
"""
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
"""Path to the invokeai.db file."""
|
||||
db_dir = self._resolve(self.db_dir)
|
||||
assert db_dir is not None
|
||||
return db_dir / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to models configuration file.
|
||||
"""
|
||||
def model_conf_path(self) -> Optional[Path]:
|
||||
"""Path to models configuration file."""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
"""
|
||||
def legacy_conf_path(self) -> Optional[Path]:
|
||||
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self) -> Path:
|
||||
"""
|
||||
Path to the models directory
|
||||
"""
|
||||
def models_path(self) -> Optional[Path]:
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""
|
||||
Path to the custom nodes directory
|
||||
"""
|
||||
return self._resolve(self.custom_nodes_dir)
|
||||
"""Path to the custom nodes directory."""
|
||||
custom_nodes_path = self._resolve(self.custom_nodes_dir)
|
||||
assert custom_nodes_path is not None
|
||||
return custom_nodes_path
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self) -> bool:
|
||||
"""Return true if precision set to float32"""
|
||||
"""Return true if precision set to float32."""
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true"""
|
||||
"""Return true if patchmatch true."""
|
||||
return self.patchmatch
|
||||
|
||||
@property
|
||||
def nsfw_checker(self) -> bool:
|
||||
"""NSFW node is always active and disabled from Web UIe"""
|
||||
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def invisible_watermark(self) -> bool:
|
||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the ram cache size using the legacy or modern setting."""
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the vram cache size using the legacy or modern setting."""
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
def use_cpu(self) -> bool:
|
||||
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
|
||||
return self.always_use_cpu or self.device == "cpu"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""
|
||||
Return true if enable_xformers is false (reversed logic)
|
||||
and attention type is not set to xformers.
|
||||
"""
|
||||
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
|
||||
disabled_in_config = not self.xformers_enabled
|
||||
return disabled_in_config and self.attention_type != "xformers"
|
||||
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
"""
|
||||
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||
return _find_root()
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
"""
|
||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
|
@ -48,7 +48,6 @@ from typing import List, Optional, Union
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
@ -158,7 +157,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -255,7 +254,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@ -368,7 +367,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
@ -382,7 +381,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
|
@ -238,7 +238,7 @@ class ModelProbe(object):
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path)
|
||||
return torch.load(model_path, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
||||
|
||||
"""invokeai.backend.util.logging
|
||||
|
||||
Logging class for InvokeAI that produces console messages
|
||||
"""
|
||||
Logging class for InvokeAI that produces console messages.
|
||||
|
||||
Usage:
|
||||
|
||||
@ -178,8 +177,8 @@ InvokeAI:
|
||||
import logging.handlers
|
||||
import socket
|
||||
import urllib.parse
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
@ -192,36 +191,36 @@ except ImportError:
|
||||
|
||||
|
||||
# module level functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
def debug(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().debug(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(msg, *args, **kwargs):
|
||||
def info(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().info(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(msg, *args, **kwargs):
|
||||
def warning(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().warning(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(msg, *args, **kwargs):
|
||||
def error(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().error(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(msg, *args, **kwargs):
|
||||
def critical(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().critical(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def log(level, msg, *args, **kwargs):
|
||||
def log(level: int, msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
|
||||
InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def disable(level=logging.CRITICAL):
|
||||
InvokeAILogger.get_logger().disable(level)
|
||||
def disable(level: int = logging.CRITICAL) -> None: # noqa D103
|
||||
logging.disable(level)
|
||||
|
||||
|
||||
def basicConfig(**kwargs):
|
||||
InvokeAILogger.get_logger().basicConfig(**kwargs)
|
||||
def basicConfig(**kwargs: Any) -> None: # noqa D103
|
||||
logging.basicConfig(**kwargs)
|
||||
|
||||
|
||||
_FACILITY_MAP = (
|
||||
@ -256,33 +255,25 @@ _SOCK_MAP = {
|
||||
|
||||
|
||||
class InvokeAIFormatter(logging.Formatter):
|
||||
"""
|
||||
Base class for logging formatter
|
||||
"""Base class for logging formatter."""
|
||||
|
||||
"""
|
||||
|
||||
def format(self, record):
|
||||
def format(self, record: logging.LogRecord) -> str: # noqa D102
|
||||
formatter = logging.Formatter(self.log_fmt(record.levelno))
|
||||
return formatter.format(record)
|
||||
|
||||
@abstractmethod
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
pass
|
||||
def log_fmt(self, levelno: int) -> str: # noqa D102
|
||||
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
|
||||
|
||||
class InvokeAISyslogFormatter(InvokeAIFormatter):
|
||||
"""
|
||||
Formatting for syslog
|
||||
"""
|
||||
"""Formatting for syslog."""
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
def log_fmt(self, levelno: int) -> str: # noqa D102
|
||||
return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
|
||||
|
||||
|
||||
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
|
||||
"""
|
||||
Formatting for the InvokeAI Logger (legacy version)
|
||||
"""
|
||||
class InvokeAILegacyLogFormatter(InvokeAIFormatter): # noqa D102
|
||||
"""Formatting for the InvokeAI Logger (legacy version)."""
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: " | %(message)s",
|
||||
@ -292,23 +283,21 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
|
||||
logging.CRITICAL: "### %(message)s",
|
||||
}
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
return self.FORMATS.get(levelno)
|
||||
def log_fmt(self, levelno: int) -> str: # noqa D102
|
||||
format = self.FORMATS.get(levelno)
|
||||
assert format is not None
|
||||
return format
|
||||
|
||||
|
||||
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
|
||||
"""
|
||||
Custom Formatting for the InvokeAI Logger (plain version)
|
||||
"""
|
||||
"""Custom Formatting for the InvokeAI Logger (plain version)."""
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
def log_fmt(self, levelno: int) -> str: # noqa D102
|
||||
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
|
||||
|
||||
class InvokeAIColorLogFormatter(InvokeAIFormatter):
|
||||
"""
|
||||
Custom Formatting for the InvokeAI Logger
|
||||
"""
|
||||
"""Custom Formatting for the InvokeAI Logger."""
|
||||
|
||||
# Color Codes
|
||||
grey = "\x1b[38;20m"
|
||||
@ -331,8 +320,10 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
|
||||
logging.CRITICAL: bold_red + log_format + reset,
|
||||
}
|
||||
|
||||
def log_fmt(self, levelno: int) -> str:
|
||||
return self.FORMATS.get(levelno)
|
||||
def log_fmt(self, levelno: int) -> str: # noqa D102
|
||||
format = self.FORMATS.get(levelno)
|
||||
assert format is not None
|
||||
return format
|
||||
|
||||
|
||||
LOG_FORMATTERS = {
|
||||
@ -343,13 +334,13 @@ LOG_FORMATTERS = {
|
||||
}
|
||||
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = {}
|
||||
class InvokeAILogger(object): # noqa D102
|
||||
loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@classmethod
|
||||
def get_logger(
|
||||
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
) -> logging.Logger:
|
||||
) -> logging.Logger: # noqa D102
|
||||
if name in cls.loggers:
|
||||
logger = cls.loggers[name]
|
||||
logger.handlers.clear()
|
||||
@ -362,7 +353,7 @@ class InvokeAILogger(object):
|
||||
return cls.loggers[name]
|
||||
|
||||
@classmethod
|
||||
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: # noqa D102
|
||||
handler_strs = config.log_handlers
|
||||
handlers = []
|
||||
for handler in handler_strs:
|
||||
@ -374,7 +365,7 @@ class InvokeAILogger(object):
|
||||
# http gets no custom formatter
|
||||
formatter = LOG_FORMATTERS[config.log_format]
|
||||
if handler_name == "console":
|
||||
ch = logging.StreamHandler()
|
||||
ch: logging.Handler = logging.StreamHandler()
|
||||
ch.setFormatter(formatter())
|
||||
handlers.append(ch)
|
||||
|
||||
@ -393,18 +384,18 @@ class InvokeAILogger(object):
|
||||
return handlers
|
||||
|
||||
@staticmethod
|
||||
def _parse_syslog_args(args: str = None) -> logging.Handler:
|
||||
def _parse_syslog_args(args: Optional[str] = None) -> logging.Handler:
|
||||
if not SYSLOG_AVAILABLE:
|
||||
raise ValueError("syslog is not available on this system")
|
||||
if not args:
|
||||
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
|
||||
syslog_args = {}
|
||||
syslog_args: Dict[str, Any] = {}
|
||||
try:
|
||||
for a in args.split(","):
|
||||
arg_name, *arg_value = a.split(":", 2)
|
||||
if arg_name == "address":
|
||||
host, *port = arg_value
|
||||
port = 514 if len(port) == 0 else int(port[0])
|
||||
host, *port_list = arg_value
|
||||
port = 514 if not port_list else int(port_list[0])
|
||||
syslog_args["address"] = (host, port)
|
||||
elif arg_name == "facility":
|
||||
syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
|
||||
@ -417,13 +408,13 @@ class InvokeAILogger(object):
|
||||
return logging.handlers.SysLogHandler(**syslog_args)
|
||||
|
||||
@staticmethod
|
||||
def _parse_file_args(args: str = None) -> logging.Handler:
|
||||
def _parse_file_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
|
||||
if not args:
|
||||
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
|
||||
return logging.FileHandler(args)
|
||||
|
||||
@staticmethod
|
||||
def _parse_http_args(args: str = None) -> logging.Handler:
|
||||
def _parse_http_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
|
||||
if not args:
|
||||
raise ValueError("please provide destination for http logging using format 'http=url'")
|
||||
arg_list = args.split(",")
|
||||
@ -434,12 +425,12 @@ class InvokeAILogger(object):
|
||||
path = url.path
|
||||
port = url.port or 80
|
||||
|
||||
syslog_args = {}
|
||||
syslog_args: Dict[str, Any] = {}
|
||||
for a in arg_list:
|
||||
arg_name, *arg_value = a.split(":", 2)
|
||||
if arg_name == "method":
|
||||
arg_value = arg_value[0] if len(arg_value) > 0 else "GET"
|
||||
syslog_args[arg_name] = arg_value
|
||||
method = arg_value[0] if len(arg_value) > 0 else "GET"
|
||||
syslog_args[arg_name] = method
|
||||
else: # TODO: Provide support for SSL context and credentials
|
||||
pass
|
||||
return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)
|
||||
|
@ -229,8 +229,6 @@ module = [
|
||||
"invokeai.app.api.routers.models",
|
||||
"invokeai.app.invocations.compel",
|
||||
"invokeai.app.invocations.latent",
|
||||
"invokeai.app.services.config.config_base",
|
||||
"invokeai.app.services.config.config_default",
|
||||
"invokeai.app.services.invocation_stats.invocation_stats_default",
|
||||
"invokeai.app.services.model_manager.model_manager_base",
|
||||
"invokeai.app.services.model_manager.model_manager_default",
|
||||
@ -265,7 +263,6 @@ module = [
|
||||
"invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
|
||||
"invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
|
||||
"invokeai.backend.util.hotfixes",
|
||||
"invokeai.backend.util.logging",
|
||||
"invokeai.backend.util.mps_fixes",
|
||||
"invokeai.backend.util.util",
|
||||
"invokeai.frontend.install.model_install",
|
||||
|
Loading…
Reference in New Issue
Block a user