Merge branch 'main' into lstein/config-management-fixes

This commit is contained in:
Lincoln Stein
2023-06-04 18:17:43 -04:00
committed by GitHub
140 changed files with 4647 additions and 695 deletions

View File

@ -35,15 +35,19 @@ from transformers import (
CLIPTextModel,
CLIPTokenizer,
)
import invokeai.configs as configs
from invokeai.app.services.config import (
get_invokeai_config,
InvokeAIAppConfig,
)
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
)
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
from invokeai.backend.config.model_install_backend import (
default_dataset,
@ -51,6 +55,7 @@ from invokeai.backend.config.model_install_backend import (
hf_download_with_resume,
recommended_datasets,
)
from invokeai.app.services.config import InvokeAIAppConfig
warnings.filterwarnings("ignore")
@ -59,6 +64,7 @@ transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
@ -817,6 +823,7 @@ def main():
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))

View File

@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"

View File

@ -39,8 +39,8 @@ def get_uc_and_c_and_ec(prompt_string,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text)
]
if len(parsed_prompts) == 0:
return []
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))

View File

@ -17,3 +17,5 @@ from .util import (
instantiate_from_config,
url_attachment_name,
)

View File

@ -31,7 +31,20 @@ IAILogger.debug('this is a debugging message')
"""
import logging
import logging.handlers
import socket
import urllib.parse
from abc import abstractmethod
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
try:
import syslog
SYSLOG_AVAILABLE = True
except:
SYSLOG_AVAILABLE = False
# module level functions
def debug(msg, *args, **kwargs):
@ -62,11 +75,77 @@ def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter):
_FACILITY_MAP = dict(
LOG_KERN = syslog.LOG_KERN,
LOG_USER = syslog.LOG_USER,
LOG_MAIL = syslog.LOG_MAIL,
LOG_DAEMON = syslog.LOG_DAEMON,
LOG_AUTH = syslog.LOG_AUTH,
LOG_LPR = syslog.LOG_LPR,
LOG_NEWS = syslog.LOG_NEWS,
LOG_UUCP = syslog.LOG_UUCP,
LOG_CRON = syslog.LOG_CRON,
LOG_SYSLOG = syslog.LOG_SYSLOG,
LOG_LOCAL0 = syslog.LOG_LOCAL0,
LOG_LOCAL1 = syslog.LOG_LOCAL1,
LOG_LOCAL2 = syslog.LOG_LOCAL2,
LOG_LOCAL3 = syslog.LOG_LOCAL3,
LOG_LOCAL4 = syslog.LOG_LOCAL4,
LOG_LOCAL5 = syslog.LOG_LOCAL5,
LOG_LOCAL6 = syslog.LOG_LOCAL6,
LOG_LOCAL7 = syslog.LOG_LOCAL7,
) if SYSLOG_AVAILABLE else dict()
_SOCK_MAP = dict(
SOCK_STREAM = socket.SOCK_STREAM,
SOCK_DGRAM = socket.SOCK_DGRAM,
)
class InvokeAIFormatter(logging.Formatter):
'''
Base class for logging formatter
'''
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record)
@abstractmethod
def log_fmt(self, levelno: int)->str:
pass
class InvokeAISyslogFormatter(InvokeAIFormatter):
'''
Formatting for syslog
'''
def log_fmt(self, levelno: int)->str:
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
'''
Formatting for the InvokeAI Logger (legacy version)
'''
FORMATS = {
logging.DEBUG: " | %(message)s",
logging.INFO: ">> %(message)s",
logging.WARNING: "** %(message)s",
logging.ERROR: "*** %(message)s",
logging.CRITICAL: "### %(message)s",
}
def log_fmt(self,levelno:int)->str:
return self.FORMATS.get(levelno)
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
'''
Custom Formatting for the InvokeAI Logger (plain version)
'''
def log_fmt(self, levelno: int)->str:
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter):
'''
Custom Formatting for the InvokeAI Logger
'''
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
@ -88,23 +167,109 @@ class InvokeAILogFormatter(logging.Formatter):
logging.CRITICAL: bold_red + log_format + reset
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
return formatter.format(record)
def log_fmt(self, levelno: int)->str:
return self.FORMATS.get(levelno)
LOG_FORMATTERS = {
'plain': InvokeAIPlainLogFormatter,
'color': InvokeAIColorLogFormatter,
'syslog': InvokeAISyslogFormatter,
'legacy': InvokeAILegacyLogFormatter,
}
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
config = get_invokeai_config()
if name not in cls.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
@classmethod
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
print(f'handler_strs={handler_strs}')
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None
# console is the only handler that gets a custom formatter
if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='syslog':
ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='file':
handlers.append(cls._parse_file_args(args))
elif handler_name=='http':
handlers.append(cls._parse_http_args(args))
return handlers
@staticmethod
def _parse_syslog_args(
args: str=None
)-> logging.Handler:
if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system")
if not args:
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
syslog_args = dict()
try:
for a in args.split(','):
arg_name,*arg_value = a.split(':',2)
if arg_name=='address':
host,*port = arg_value
port = 514 if len(port)==0 else int(port[0])
syslog_args['address'] = (host,port)
elif arg_name=='facility':
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
elif arg_name=='socktype':
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
else:
syslog_args['address'] = arg_name
except:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod
def _parse_file_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args)
@staticmethod
def _parse_http_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(',')
url = urllib.parse.urlparse(arg_list.pop(0))
if url.scheme != 'http':
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
host = url.hostname
path = url.path
port = url.port or 80
syslog_args = dict()
for a in arg_list:
arg_name, *arg_value = a.split(':',2)
if arg_name=='method':
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
syslog_args[arg_name] = arg_value
else: # TODO: Provide support for SSL context and credentials
pass
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)