fix a bunch of type mismatches in the logging module

This commit is contained in:
Lincoln Stein 2023-11-26 22:25:12 -05:00 committed by psychedelicious
parent e28262ebd9
commit ae82df0fda

View File

@ -1,8 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.backend.util.logging """
Logging class for InvokeAI that produces console messages.
Logging class for InvokeAI that produces console messages
Usage: Usage:
@ -178,8 +177,8 @@ InvokeAI:
import logging.handlers import logging.handlers
import socket import socket
import urllib.parse import urllib.parse
from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -192,36 +191,36 @@ except ImportError:
# module level functions # module level functions
def debug(msg, *args, **kwargs): def debug(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().debug(msg, *args, **kwargs) InvokeAILogger.get_logger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs): def info(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().info(msg, *args, **kwargs) InvokeAILogger.get_logger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs): def warning(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().warning(msg, *args, **kwargs) InvokeAILogger.get_logger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs): def error(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().error(msg, *args, **kwargs) InvokeAILogger.get_logger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs): def critical(msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().critical(msg, *args, **kwargs) InvokeAILogger.get_logger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs): def log(level: int, msg: str, *args: str, **kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().log(level, msg, *args, **kwargs) InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL): def disable(level: int = logging.CRITICAL) -> None: # noqa D103
InvokeAILogger.get_logger().disable(level) logging.disable(level)
def basicConfig(**kwargs): def basicConfig(**kwargs: Any) -> None: # noqa D103
InvokeAILogger.get_logger().basicConfig(**kwargs) logging.basicConfig(**kwargs)
_FACILITY_MAP = ( _FACILITY_MAP = (
@ -256,33 +255,25 @@ _SOCK_MAP = {
class InvokeAIFormatter(logging.Formatter): class InvokeAIFormatter(logging.Formatter):
""" """Base class for logging formatter."""
Base class for logging formatter
""" def format(self, record: logging.LogRecord) -> str: # noqa D102
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno)) formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record) return formatter.format(record)
@abstractmethod def log_fmt(self, levelno: int) -> str: # noqa D102
def log_fmt(self, levelno: int) -> str: return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
pass
class InvokeAISyslogFormatter(InvokeAIFormatter): class InvokeAISyslogFormatter(InvokeAIFormatter):
""" """Formatting for syslog."""
Formatting for syslog
"""
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return "%(name)s [%(process)d] <%(levelname)s> %(message)s" return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
class InvokeAILegacyLogFormatter(InvokeAIFormatter): class InvokeAILegacyLogFormatter(InvokeAIFormatter): # noqa D102
""" """Formatting for the InvokeAI Logger (legacy version)."""
Formatting for the InvokeAI Logger (legacy version)
"""
FORMATS = { FORMATS = {
logging.DEBUG: " | %(message)s", logging.DEBUG: " | %(message)s",
@ -292,23 +283,21 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
logging.CRITICAL: "### %(message)s", logging.CRITICAL: "### %(message)s",
} }
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return self.FORMATS.get(levelno) format = self.FORMATS.get(levelno)
assert format is not None
return format
class InvokeAIPlainLogFormatter(InvokeAIFormatter): class InvokeAIPlainLogFormatter(InvokeAIFormatter):
""" """Custom Formatting for the InvokeAI Logger (plain version)."""
Custom Formatting for the InvokeAI Logger (plain version)
"""
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s" return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter): class InvokeAIColorLogFormatter(InvokeAIFormatter):
""" """Custom Formatting for the InvokeAI Logger."""
Custom Formatting for the InvokeAI Logger
"""
# Color Codes # Color Codes
grey = "\x1b[38;20m" grey = "\x1b[38;20m"
@ -331,8 +320,10 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
logging.CRITICAL: bold_red + log_format + reset, logging.CRITICAL: bold_red + log_format + reset,
} }
def log_fmt(self, levelno: int) -> str: def log_fmt(self, levelno: int) -> str: # noqa D102
return self.FORMATS.get(levelno) format = self.FORMATS.get(levelno)
assert format is not None
return format
LOG_FORMATTERS = { LOG_FORMATTERS = {
@ -343,13 +334,13 @@ LOG_FORMATTERS = {
} }
class InvokeAILogger(object): class InvokeAILogger(object): # noqa D102
loggers = {} loggers: Dict[str, logging.Logger] = {}
@classmethod @classmethod
def get_logger( def get_logger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger: ) -> logging.Logger: # noqa D102
if name in cls.loggers: if name in cls.loggers:
logger = cls.loggers[name] logger = cls.loggers[name]
logger.handlers.clear() logger.handlers.clear()
@ -362,7 +353,7 @@ class InvokeAILogger(object):
return cls.loggers[name] return cls.loggers[name]
@classmethod @classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]: # noqa D102
handler_strs = config.log_handlers handler_strs = config.log_handlers
handlers = [] handlers = []
for handler in handler_strs: for handler in handler_strs:
@ -374,7 +365,7 @@ class InvokeAILogger(object):
# http gets no custom formatter # http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format] formatter = LOG_FORMATTERS[config.log_format]
if handler_name == "console": if handler_name == "console":
ch = logging.StreamHandler() ch: logging.Handler = logging.StreamHandler()
ch.setFormatter(formatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
@ -393,18 +384,18 @@ class InvokeAILogger(object):
return handlers return handlers
@staticmethod @staticmethod
def _parse_syslog_args(args: str = None) -> logging.Handler: def _parse_syslog_args(args: Optional[str] = None) -> logging.Handler:
if not SYSLOG_AVAILABLE: if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system") raise ValueError("syslog is not available on this system")
if not args: if not args:
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514" args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = {} syslog_args: Dict[str, Any] = {}
try: try:
for a in args.split(","): for a in args.split(","):
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
if arg_name == "address": if arg_name == "address":
host, *port = arg_value host, *port_list = arg_value
port = 514 if len(port) == 0 else int(port[0]) port = 514 if not port_list else int(port_list[0])
syslog_args["address"] = (host, port) syslog_args["address"] = (host, port)
elif arg_name == "facility": elif arg_name == "facility":
syslog_args["facility"] = _FACILITY_MAP[arg_value[0]] syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
@ -417,13 +408,13 @@ class InvokeAILogger(object):
return logging.handlers.SysLogHandler(**syslog_args) return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod @staticmethod
def _parse_file_args(args: str = None) -> logging.Handler: def _parse_file_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
if not args: if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'") raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args) return logging.FileHandler(args)
@staticmethod @staticmethod
def _parse_http_args(args: str = None) -> logging.Handler: def _parse_http_args(args: Optional[str] = None) -> logging.Handler: # noqa D102
if not args: if not args:
raise ValueError("please provide destination for http logging using format 'http=url'") raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(",") arg_list = args.split(",")
@ -434,12 +425,12 @@ class InvokeAILogger(object):
path = url.path path = url.path
port = url.port or 80 port = url.port or 80
syslog_args = {} syslog_args: Dict[str, Any] = {}
for a in arg_list: for a in arg_list:
arg_name, *arg_value = a.split(":", 2) arg_name, *arg_value = a.split(":", 2)
if arg_name == "method": if arg_name == "method":
arg_value = arg_value[0] if len(arg_value) > 0 else "GET" method = arg_value[0] if len(arg_value) > 0 else "GET"
syslog_args[arg_name] = arg_value syslog_args[arg_name] = method
else: # TODO: Provide support for SSL context and credentials else: # TODO: Provide support for SSL context and credentials
pass pass
return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args) return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)