mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/config-management-fixes
This commit is contained in:
@ -35,15 +35,19 @@ from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import (
|
||||
get_invokeai_config,
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.config.model_install_backend import (
|
||||
default_dataset,
|
||||
@ -51,6 +55,7 @@ from invokeai.backend.config.model_install_backend import (
|
||||
hf_download_with_resume,
|
||||
recommended_datasets,
|
||||
)
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -59,6 +64,7 @@ transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
@ -817,6 +823,7 @@ def main():
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
|
||||
# Load new init file into config
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
|
||||
|
@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
|
@ -39,8 +39,8 @@ def get_uc_and_c_and_ec(prompt_string,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||
for match in re.finditer(prompt_parser, text)
|
||||
]
|
||||
if len(parsed_prompts) == 0:
|
||||
return []
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
|
@ -17,3 +17,5 @@ from .util import (
|
||||
instantiate_from_config,
|
||||
url_attachment_name,
|
||||
)
|
||||
|
||||
|
||||
|
@ -31,7 +31,20 @@ IAILogger.debug('this is a debugging message')
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import socket
|
||||
import urllib.parse
|
||||
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
|
||||
|
||||
try:
|
||||
import syslog
|
||||
SYSLOG_AVAILABLE = True
|
||||
except:
|
||||
SYSLOG_AVAILABLE = False
|
||||
|
||||
# module level functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
@ -62,11 +75,77 @@ def getLogger(name: str = None) -> logging.Logger:
|
||||
return InvokeAILogger.getLogger(name)
|
||||
|
||||
|
||||
class InvokeAILogFormatter(logging.Formatter):
|
||||
_FACILITY_MAP = dict(
|
||||
LOG_KERN = syslog.LOG_KERN,
|
||||
LOG_USER = syslog.LOG_USER,
|
||||
LOG_MAIL = syslog.LOG_MAIL,
|
||||
LOG_DAEMON = syslog.LOG_DAEMON,
|
||||
LOG_AUTH = syslog.LOG_AUTH,
|
||||
LOG_LPR = syslog.LOG_LPR,
|
||||
LOG_NEWS = syslog.LOG_NEWS,
|
||||
LOG_UUCP = syslog.LOG_UUCP,
|
||||
LOG_CRON = syslog.LOG_CRON,
|
||||
LOG_SYSLOG = syslog.LOG_SYSLOG,
|
||||
LOG_LOCAL0 = syslog.LOG_LOCAL0,
|
||||
LOG_LOCAL1 = syslog.LOG_LOCAL1,
|
||||
LOG_LOCAL2 = syslog.LOG_LOCAL2,
|
||||
LOG_LOCAL3 = syslog.LOG_LOCAL3,
|
||||
LOG_LOCAL4 = syslog.LOG_LOCAL4,
|
||||
LOG_LOCAL5 = syslog.LOG_LOCAL5,
|
||||
LOG_LOCAL6 = syslog.LOG_LOCAL6,
|
||||
LOG_LOCAL7 = syslog.LOG_LOCAL7,
|
||||
) if SYSLOG_AVAILABLE else dict()
|
||||
|
||||
_SOCK_MAP = dict(
|
||||
SOCK_STREAM = socket.SOCK_STREAM,
|
||||
SOCK_DGRAM = socket.SOCK_DGRAM,
|
||||
)
|
||||
|
||||
class InvokeAIFormatter(logging.Formatter):
|
||||
'''
|
||||
Base class for logging formatter
|
||||
|
||||
'''
|
||||
def format(self, record):
|
||||
formatter = logging.Formatter(self.log_fmt(record.levelno))
|
||||
return formatter.format(record)
|
||||
|
||||
@abstractmethod
|
||||
def log_fmt(self, levelno: int)->str:
|
||||
pass
|
||||
|
||||
class InvokeAISyslogFormatter(InvokeAIFormatter):
|
||||
'''
|
||||
Formatting for syslog
|
||||
'''
|
||||
def log_fmt(self, levelno: int)->str:
|
||||
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
|
||||
|
||||
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
|
||||
'''
|
||||
Formatting for the InvokeAI Logger (legacy version)
|
||||
'''
|
||||
FORMATS = {
|
||||
logging.DEBUG: " | %(message)s",
|
||||
logging.INFO: ">> %(message)s",
|
||||
logging.WARNING: "** %(message)s",
|
||||
logging.ERROR: "*** %(message)s",
|
||||
logging.CRITICAL: "### %(message)s",
|
||||
}
|
||||
def log_fmt(self,levelno:int)->str:
|
||||
return self.FORMATS.get(levelno)
|
||||
|
||||
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
|
||||
'''
|
||||
Custom Formatting for the InvokeAI Logger (plain version)
|
||||
'''
|
||||
def log_fmt(self, levelno: int)->str:
|
||||
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
|
||||
class InvokeAIColorLogFormatter(InvokeAIFormatter):
|
||||
'''
|
||||
Custom Formatting for the InvokeAI Logger
|
||||
'''
|
||||
|
||||
# Color Codes
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
@ -88,23 +167,109 @@ class InvokeAILogFormatter(logging.Formatter):
|
||||
logging.CRITICAL: bold_red + log_format + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
|
||||
return formatter.format(record)
|
||||
def log_fmt(self, levelno: int)->str:
|
||||
return self.FORMATS.get(levelno)
|
||||
|
||||
LOG_FORMATTERS = {
|
||||
'plain': InvokeAIPlainLogFormatter,
|
||||
'color': InvokeAIColorLogFormatter,
|
||||
'syslog': InvokeAISyslogFormatter,
|
||||
'legacy': InvokeAILegacyLogFormatter,
|
||||
}
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
|
||||
@classmethod
|
||||
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
||||
config = get_invokeai_config()
|
||||
|
||||
if name not in cls.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ch = logging.StreamHandler()
|
||||
fmt = InvokeAILogFormatter()
|
||||
ch.setFormatter(fmt)
|
||||
logger.addHandler(ch)
|
||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||
for ch in cls.getLoggers(config):
|
||||
logger.addHandler(ch)
|
||||
cls.loggers[name] = logger
|
||||
return cls.loggers[name]
|
||||
|
||||
@classmethod
|
||||
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||
handler_strs = config.log_handlers
|
||||
print(f'handler_strs={handler_strs}')
|
||||
handlers = list()
|
||||
for handler in handler_strs:
|
||||
handler_name,*args = handler.split('=',2)
|
||||
args = args[0] if len(args) > 0 else None
|
||||
|
||||
# console is the only handler that gets a custom formatter
|
||||
if handler_name=='console':
|
||||
formatter = LOG_FORMATTERS[config.log_format]
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='syslog':
|
||||
ch = cls._parse_syslog_args(args)
|
||||
ch.setFormatter(InvokeAISyslogFormatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='file':
|
||||
handlers.append(cls._parse_file_args(args))
|
||||
|
||||
elif handler_name=='http':
|
||||
handlers.append(cls._parse_http_args(args))
|
||||
return handlers
|
||||
|
||||
@staticmethod
|
||||
def _parse_syslog_args(
|
||||
args: str=None
|
||||
)-> logging.Handler:
|
||||
if not SYSLOG_AVAILABLE:
|
||||
raise ValueError("syslog is not available on this system")
|
||||
if not args:
|
||||
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
|
||||
syslog_args = dict()
|
||||
try:
|
||||
for a in args.split(','):
|
||||
arg_name,*arg_value = a.split(':',2)
|
||||
if arg_name=='address':
|
||||
host,*port = arg_value
|
||||
port = 514 if len(port)==0 else int(port[0])
|
||||
syslog_args['address'] = (host,port)
|
||||
elif arg_name=='facility':
|
||||
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
|
||||
elif arg_name=='socktype':
|
||||
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
|
||||
else:
|
||||
syslog_args['address'] = arg_name
|
||||
except:
|
||||
raise ValueError(f"{args} is not a value argument list for syslog logging")
|
||||
return logging.handlers.SysLogHandler(**syslog_args)
|
||||
|
||||
@staticmethod
|
||||
def _parse_file_args(args: str=None)-> logging.Handler:
|
||||
if not args:
|
||||
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
|
||||
return logging.FileHandler(args)
|
||||
|
||||
@staticmethod
|
||||
def _parse_http_args(args: str=None)-> logging.Handler:
|
||||
if not args:
|
||||
raise ValueError("please provide destination for http logging using format 'http=url'")
|
||||
arg_list = args.split(',')
|
||||
url = urllib.parse.urlparse(arg_list.pop(0))
|
||||
if url.scheme != 'http':
|
||||
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
|
||||
host = url.hostname
|
||||
path = url.path
|
||||
port = url.port or 80
|
||||
|
||||
syslog_args = dict()
|
||||
for a in arg_list:
|
||||
arg_name, *arg_value = a.split(':',2)
|
||||
if arg_name=='method':
|
||||
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
|
||||
syslog_args[arg_name] = arg_value
|
||||
else: # TODO: Provide support for SSL context and credentials
|
||||
pass
|
||||
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
|
||||
|
Reference in New Issue
Block a user