Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@ -11,12 +11,4 @@ from .devices import (
torch_dtype,
)
from .log import write_log
from .util import (
ask_user,
download_with_resume,
instantiate_from_config,
url_attachment_name,
Chdir
)
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir

View File

@ -12,6 +12,7 @@ CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if config.always_use_cpu:

View File

@ -20,6 +20,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
@ -618,9 +619,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
@ -630,5 +629,6 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

View File

@ -186,89 +186,109 @@ from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
try:
import syslog
SYSLOG_AVAILABLE = True
except:
SYSLOG_AVAILABLE = False
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs):
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs):
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL):
InvokeAILogger.getLogger().disable(level)
def basicConfig(**kwargs):
InvokeAILogger.getLogger().basicConfig(**kwargs)
def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
_FACILITY_MAP = dict(
LOG_KERN = syslog.LOG_KERN,
LOG_USER = syslog.LOG_USER,
LOG_MAIL = syslog.LOG_MAIL,
LOG_DAEMON = syslog.LOG_DAEMON,
LOG_AUTH = syslog.LOG_AUTH,
LOG_LPR = syslog.LOG_LPR,
LOG_NEWS = syslog.LOG_NEWS,
LOG_UUCP = syslog.LOG_UUCP,
LOG_CRON = syslog.LOG_CRON,
LOG_SYSLOG = syslog.LOG_SYSLOG,
LOG_LOCAL0 = syslog.LOG_LOCAL0,
LOG_LOCAL1 = syslog.LOG_LOCAL1,
LOG_LOCAL2 = syslog.LOG_LOCAL2,
LOG_LOCAL3 = syslog.LOG_LOCAL3,
LOG_LOCAL4 = syslog.LOG_LOCAL4,
LOG_LOCAL5 = syslog.LOG_LOCAL5,
LOG_LOCAL6 = syslog.LOG_LOCAL6,
LOG_LOCAL7 = syslog.LOG_LOCAL7,
) if SYSLOG_AVAILABLE else dict()
_SOCK_MAP = dict(
SOCK_STREAM = socket.SOCK_STREAM,
SOCK_DGRAM = socket.SOCK_DGRAM,
_FACILITY_MAP = (
dict(
LOG_KERN=syslog.LOG_KERN,
LOG_USER=syslog.LOG_USER,
LOG_MAIL=syslog.LOG_MAIL,
LOG_DAEMON=syslog.LOG_DAEMON,
LOG_AUTH=syslog.LOG_AUTH,
LOG_LPR=syslog.LOG_LPR,
LOG_NEWS=syslog.LOG_NEWS,
LOG_UUCP=syslog.LOG_UUCP,
LOG_CRON=syslog.LOG_CRON,
LOG_SYSLOG=syslog.LOG_SYSLOG,
LOG_LOCAL0=syslog.LOG_LOCAL0,
LOG_LOCAL1=syslog.LOG_LOCAL1,
LOG_LOCAL2=syslog.LOG_LOCAL2,
LOG_LOCAL3=syslog.LOG_LOCAL3,
LOG_LOCAL4=syslog.LOG_LOCAL4,
LOG_LOCAL5=syslog.LOG_LOCAL5,
LOG_LOCAL6=syslog.LOG_LOCAL6,
LOG_LOCAL7=syslog.LOG_LOCAL7,
)
if SYSLOG_AVAILABLE
else dict()
)
_SOCK_MAP = dict(
SOCK_STREAM=socket.SOCK_STREAM,
SOCK_DGRAM=socket.SOCK_DGRAM,
)
class InvokeAIFormatter(logging.Formatter):
'''
"""
Base class for logging formatter
'''
"""
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record)
@abstractmethod
def log_fmt(self, levelno: int)->str:
def log_fmt(self, levelno: int) -> str:
pass
class InvokeAISyslogFormatter(InvokeAIFormatter):
'''
"""
Formatting for syslog
'''
def log_fmt(self, levelno: int)->str:
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
"""
def log_fmt(self, levelno: int) -> str:
return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
'''
"""
Formatting for the InvokeAI Logger (legacy version)
'''
"""
FORMATS = {
logging.DEBUG: " | %(message)s",
logging.INFO: ">> %(message)s",
@ -276,20 +296,25 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
logging.ERROR: "*** %(message)s",
logging.CRITICAL: "### %(message)s",
}
def log_fmt(self,levelno:int)->str:
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
'''
"""
Custom Formatting for the InvokeAI Logger (plain version)
'''
def log_fmt(self, levelno: int)->str:
"""
def log_fmt(self, levelno: int) -> str:
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter):
'''
"""
Custom Formatting for the InvokeAI Logger
'''
"""
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
@ -308,32 +333,34 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
logging.INFO: grey + log_format + reset,
logging.WARNING: yellow + log_format + reset,
logging.ERROR: red + log_format + reset,
logging.CRITICAL: bold_red + log_format + reset
logging.CRITICAL: bold_red + log_format + reset,
}
def log_fmt(self, levelno: int)->str:
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
LOG_FORMATTERS = {
'plain': InvokeAIPlainLogFormatter,
'color': InvokeAIColorLogFormatter,
'syslog': InvokeAISyslogFormatter,
'legacy': InvokeAILegacyLogFormatter,
"plain": InvokeAIPlainLogFormatter,
"color": InvokeAIColorLogFormatter,
"syslog": InvokeAISyslogFormatter,
"legacy": InvokeAILegacyLogFormatter,
}
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls,
name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
def getLogger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger:
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
@ -344,82 +371,80 @@ class InvokeAILogger(object):
handler_strs = config.log_handlers
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None
# console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console':
if handler_name == "console":
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='syslog':
elif handler_name == "syslog":
ch = cls._parse_syslog_args(args)
handlers.append(ch)
elif handler_name=='file':
elif handler_name == "file":
ch = cls._parse_file_args(args)
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='http':
elif handler_name == "http":
ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers
@staticmethod
def _parse_syslog_args(
args: str=None
)-> logging.Handler:
def _parse_syslog_args(args: str = None) -> logging.Handler:
if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system")
if not args:
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = dict()
try:
for a in args.split(','):
arg_name,*arg_value = a.split(':',2)
if arg_name=='address':
host,*port = arg_value
port = 514 if len(port)==0 else int(port[0])
syslog_args['address'] = (host,port)
elif arg_name=='facility':
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
elif arg_name=='socktype':
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
for a in args.split(","):
arg_name, *arg_value = a.split(":", 2)
if arg_name == "address":
host, *port = arg_value
port = 514 if len(port) == 0 else int(port[0])
syslog_args["address"] = (host, port)
elif arg_name == "facility":
syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
elif arg_name == "socktype":
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
else:
syslog_args['address'] = arg_name
syslog_args["address"] = arg_name
except:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod
def _parse_file_args(args: str=None)-> logging.Handler:
def _parse_file_args(args: str = None) -> logging.Handler:
if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args)
@staticmethod
def _parse_http_args(args: str=None)-> logging.Handler:
def _parse_http_args(args: str = None) -> logging.Handler:
if not args:
raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(',')
arg_list = args.split(",")
url = urllib.parse.urlparse(arg_list.pop(0))
if url.scheme != 'http':
if url.scheme != "http":
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
host = url.hostname
path = url.path
port = url.port or 80
syslog_args = dict()
for a in arg_list:
arg_name, *arg_value = a.split(':',2)
if arg_name=='method':
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
arg_name, *arg_value = a.split(":", 2)
if arg_name == "method":
arg_value = arg_value[0] if len(arg_value) > 0 else "GET"
syslog_args[arg_name] = arg_value
else: # TODO: Provide support for SSL context and credentials
pass
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)

View File

@ -8,6 +8,8 @@ if torch.backends.mps.is_available():
_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@ -19,20 +21,26 @@ def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
torch.nn.functional.layer_norm = new_layer_norm
_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result
torch.Tensor.permute = new_torch_tensor_permute
_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@ -52,20 +60,36 @@ def new_torch_lerp(input, end, weight, *, out=None):
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
def new_torch_interpolate(
input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
return _torch_interpolate(
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate
# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
class ChunkedSlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
@ -78,7 +102,7 @@ class ChunkedSlicedAttnProcessor:
def __init__(self, slice_size):
assert isinstance(slice_size, int)
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
self.slice_size = slice_size
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
@ -121,7 +145,9 @@ class ChunkedSlicedAttnProcessor:
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
chunk_tmp_tensor = torch.empty(self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
chunk_tmp_tensor = torch.empty(
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
@ -131,7 +157,15 @@ class ChunkedSlicedAttnProcessor:
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
self.get_attention_scores_chunked(attn, query_slice, key_slice, attn_mask_slice, hidden_states[start_idx:end_idx], value[start_idx:end_idx], chunk_tmp_tensor)
self.get_attention_scores_chunked(
attn,
query_slice,
key_slice,
attn_mask_slice,
hidden_states[start_idx:end_idx],
value[start_idx:end_idx],
chunk_tmp_tensor,
)
hidden_states = attn.batch_to_head_dim(hidden_states)
@ -150,7 +184,6 @@ class ChunkedSlicedAttnProcessor:
return hidden_states
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
# batch size = 1
assert query.shape[0] == 1
@ -163,14 +196,14 @@ class ChunkedSlicedAttnProcessor:
query = query.float()
key = key.float()
#out_item_size = query.dtype.itemsize
#if attn.upcast_attention:
# out_item_size = query.dtype.itemsize
# if attn.upcast_attention:
# out_item_size = torch.float32.itemsize
out_item_size = query.element_size()
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2 ** 29
chunk_size = 2**29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
@ -181,8 +214,8 @@ class ChunkedSlicedAttnProcessor:
def _get_chunk_view(tensor, start, length):
if start + length > tensor.shape[1]:
length = tensor.shape[1] - start
#print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:,start:start+length]
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:, start : start + length]
for chunk_pos in range(0, query.shape[1], chunk_step):
if attention_mask is not None:
@ -196,7 +229,7 @@ class ChunkedSlicedAttnProcessor:
)
else:
torch.baddbmm(
torch.zeros((1,1,1), device=query.device, dtype=query.dtype),
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
_get_chunk_view(query, chunk_pos, chunk_step),
key,
beta=0,
@ -206,7 +239,7 @@ class ChunkedSlicedAttnProcessor:
chunk = chunk.softmax(dim=-1)
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
#del chunk
# del chunk
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor

View File

@ -32,9 +32,7 @@ def log_txt_as_img(wh, xc, size=10):
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
@ -81,9 +79,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logger.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@ -154,21 +150,12 @@ def parallel_data_prefetch(
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
]
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
else:
step = (
int(len(data) / n_proc + 1)
if len(data) % n_proc != 0
else int(len(data) / n_proc)
)
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i : i + step] for i in range(0, len(data), step)]
)
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
]
processes = []
for i in range(n_proc):
@ -220,9 +207,7 @@ def parallel_data_prefetch(
return gather_res
def rand_perlin_2d(
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
@ -265,9 +250,9 @@ def rand_perlin_2d(
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
t = fade(grid[: shape[0], : shape[1]])
noise = math.sqrt(2) * torch.lerp(
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
).to(device)
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
device
)
return noise.to(dtype=torch_dtype(device))
@ -276,9 +261,7 @@ def ask_user(question: str, answers: list):
user_prompt = f"\n>> {question} {answers}: "
invalid_answer_msg = "Invalid answer. Please try again."
pose_question = chain(
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
)
pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
@ -303,9 +286,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir():
try:
file_name = re.search(
'filename="(.+)"', resp.headers.get("Content-Disposition")
).group(1)
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
@ -322,7 +303,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
@ -377,16 +358,16 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")
return image_base64
class Chdir(object):
'''Context manager to chdir to desired directory and change back after context exits:
"""Context manager to chdir to desired directory and change back after context exits:
Args:
path (Path): The path to the cwd
'''
"""
def __init__(self, path: Path):
self.path = path
self.original = Path().absolute()
@ -394,5 +375,5 @@ class Chdir(object):
def __enter__(self):
os.chdir(self.path)
def __exit__(self,*args):
def __exit__(self, *args):
os.chdir(self.original)