mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
@@ -11,12 +11,4 @@ from .devices import (
     torch_dtype,
 )
 from .log import write_log
-from .util import (
-    ask_user,
-    download_with_resume,
-    instantiate_from_config,
-    url_attachment_name,
-    Chdir
-)
+from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
@@ -12,6 +12,7 @@ CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")
 config = InvokeAIAppConfig.get_config()
 
+
 def choose_torch_device() -> torch.device:
     """Convenience routine for guessing which GPU device to run model on"""
     if config.always_use_cpu:
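Note: choose_torch_device above implements the usual CUDA/MPS/CPU fallback (plus the always_use_cpu override shown in the hunk). A minimal sketch of that pattern, as a hypothetical standalone helper rather than the InvokeAI API:

import torch

def pick_device() -> torch.device:
    # Prefer CUDA, then Apple-silicon MPS, then plain CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")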
@@ -20,6 +20,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control
 
 # Modified ControlNetModel with encoder_attention_mask argument added
 
+
 class ControlNetModel(ModelMixin, ConfigMixin):
     """
     A ControlNet model.
@@ -618,9 +619,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             mid_block_res_sample = mid_block_res_sample * conditioning_scale
 
         if self.config.global_pool_conditions:
-            down_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
-            ]
+            down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
             mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
 
         if not return_dict:
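For context, the global pooling above collapses each residual's spatial dimensions to 1x1 while keeping batch and channel axes. A quick illustration with an arbitrary shape:

import torch

sample = torch.randn(2, 320, 64, 64)  # (batch, channels, H, W); illustrative shape
pooled = torch.mean(sample, dim=(2, 3), keepdim=True)
print(pooled.shape)  # torch.Size([2, 320, 1, 1])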
@@ -630,5 +629,6 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
         )
 
+
 diffusers.ControlNetModel = ControlNetModel
 diffusers.models.controlnet.ControlNetModel = ControlNetModel
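The two assignments at the end of this file monkey-patch diffusers so later lookups construct the modified class. A minimal sketch of the same rebinding pattern (hypothetical subclass; module path as in the diffusers version this commit targets):

import diffusers

class PatchedControlNet(diffusers.ControlNetModel):
    pass  # stand-in for the modified class above

# Rebind both the top-level alias and the defining module's attribute,
# so imports through either path resolve to the patched class.
diffusers.ControlNetModel = PatchedControlNet
diffusers.models.controlnet.ControlNetModel = PatchedControlNet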
@@ -186,89 +186,109 @@ from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
 
 try:
     import syslog
 
     SYSLOG_AVAILABLE = True
 except:
     SYSLOG_AVAILABLE = False
 
 
 # module level functions
 def debug(msg, *args, **kwargs):
     InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
 
 
 def info(msg, *args, **kwargs):
     InvokeAILogger.getLogger().info(msg, *args, **kwargs)
 
 
 def warning(msg, *args, **kwargs):
     InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
 
 
 def error(msg, *args, **kwargs):
     InvokeAILogger.getLogger().error(msg, *args, **kwargs)
 
 
 def critical(msg, *args, **kwargs):
     InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
 
 
 def log(level, msg, *args, **kwargs):
     InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
 
 
 def disable(level=logging.CRITICAL):
     InvokeAILogger.getLogger().disable(level)
 
 
 def basicConfig(**kwargs):
     InvokeAILogger.getLogger().basicConfig(**kwargs)
 
 
 def getLogger(name: str = None) -> logging.Logger:
     return InvokeAILogger.getLogger(name)
 
 
-_FACILITY_MAP = dict(
-    LOG_KERN = syslog.LOG_KERN,
-    LOG_USER = syslog.LOG_USER,
-    LOG_MAIL = syslog.LOG_MAIL,
-    LOG_DAEMON = syslog.LOG_DAEMON,
-    LOG_AUTH = syslog.LOG_AUTH,
-    LOG_LPR = syslog.LOG_LPR,
-    LOG_NEWS = syslog.LOG_NEWS,
-    LOG_UUCP = syslog.LOG_UUCP,
-    LOG_CRON = syslog.LOG_CRON,
-    LOG_SYSLOG = syslog.LOG_SYSLOG,
-    LOG_LOCAL0 = syslog.LOG_LOCAL0,
-    LOG_LOCAL1 = syslog.LOG_LOCAL1,
-    LOG_LOCAL2 = syslog.LOG_LOCAL2,
-    LOG_LOCAL3 = syslog.LOG_LOCAL3,
-    LOG_LOCAL4 = syslog.LOG_LOCAL4,
-    LOG_LOCAL5 = syslog.LOG_LOCAL5,
-    LOG_LOCAL6 = syslog.LOG_LOCAL6,
-    LOG_LOCAL7 = syslog.LOG_LOCAL7,
-) if SYSLOG_AVAILABLE else dict()
-
-_SOCK_MAP = dict(
-    SOCK_STREAM = socket.SOCK_STREAM,
-    SOCK_DGRAM = socket.SOCK_DGRAM,
-)
+_FACILITY_MAP = (
+    dict(
+        LOG_KERN=syslog.LOG_KERN,
+        LOG_USER=syslog.LOG_USER,
+        LOG_MAIL=syslog.LOG_MAIL,
+        LOG_DAEMON=syslog.LOG_DAEMON,
+        LOG_AUTH=syslog.LOG_AUTH,
+        LOG_LPR=syslog.LOG_LPR,
+        LOG_NEWS=syslog.LOG_NEWS,
+        LOG_UUCP=syslog.LOG_UUCP,
+        LOG_CRON=syslog.LOG_CRON,
+        LOG_SYSLOG=syslog.LOG_SYSLOG,
+        LOG_LOCAL0=syslog.LOG_LOCAL0,
+        LOG_LOCAL1=syslog.LOG_LOCAL1,
+        LOG_LOCAL2=syslog.LOG_LOCAL2,
+        LOG_LOCAL3=syslog.LOG_LOCAL3,
+        LOG_LOCAL4=syslog.LOG_LOCAL4,
+        LOG_LOCAL5=syslog.LOG_LOCAL5,
+        LOG_LOCAL6=syslog.LOG_LOCAL6,
+        LOG_LOCAL7=syslog.LOG_LOCAL7,
+    )
+    if SYSLOG_AVAILABLE
+    else dict()
+)
+
+_SOCK_MAP = dict(
+    SOCK_STREAM=socket.SOCK_STREAM,
+    SOCK_DGRAM=socket.SOCK_DGRAM,
+)
 
 
 class InvokeAIFormatter(logging.Formatter):
-    '''
+    """
     Base class for logging formatter
 
-    '''
+    """
+
     def format(self, record):
         formatter = logging.Formatter(self.log_fmt(record.levelno))
         return formatter.format(record)
 
     @abstractmethod
-    def log_fmt(self, levelno: int)->str:
+    def log_fmt(self, levelno: int) -> str:
         pass
 
 
 class InvokeAISyslogFormatter(InvokeAIFormatter):
-    '''
+    """
     Formatting for syslog
-    '''
-    def log_fmt(self, levelno: int)->str:
-        return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
+    """
+
+    def log_fmt(self, levelno: int) -> str:
+        return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
 
 
 class InvokeAILegacyLogFormatter(InvokeAIFormatter):
-    '''
+    """
     Formatting for the InvokeAI Logger (legacy version)
-    '''
+    """
+
     FORMATS = {
         logging.DEBUG: " | %(message)s",
         logging.INFO: ">> %(message)s",
@@ -276,20 +296,25 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
         logging.ERROR: "*** %(message)s",
         logging.CRITICAL: "### %(message)s",
     }
-    def log_fmt(self,levelno:int)->str:
+
+    def log_fmt(self, levelno: int) -> str:
         return self.FORMATS.get(levelno)
 
 
 class InvokeAIPlainLogFormatter(InvokeAIFormatter):
-    '''
+    """
     Custom Formatting for the InvokeAI Logger (plain version)
-    '''
-    def log_fmt(self, levelno: int)->str:
+    """
+
+    def log_fmt(self, levelno: int) -> str:
         return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
 
 
 class InvokeAIColorLogFormatter(InvokeAIFormatter):
-    '''
+    """
     Custom Formatting for the InvokeAI Logger
-    '''
+    """
+
     # Color Codes
     grey = "\x1b[38;20m"
     yellow = "\x1b[33;20m"
@@ -308,32 +333,34 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
         logging.INFO: grey + log_format + reset,
         logging.WARNING: yellow + log_format + reset,
         logging.ERROR: red + log_format + reset,
-        logging.CRITICAL: bold_red + log_format + reset
+        logging.CRITICAL: bold_red + log_format + reset,
     }
 
-    def log_fmt(self, levelno: int)->str:
+    def log_fmt(self, levelno: int) -> str:
         return self.FORMATS.get(levelno)
 
 
 LOG_FORMATTERS = {
-    'plain': InvokeAIPlainLogFormatter,
-    'color': InvokeAIColorLogFormatter,
-    'syslog': InvokeAISyslogFormatter,
-    'legacy': InvokeAILegacyLogFormatter,
+    "plain": InvokeAIPlainLogFormatter,
+    "color": InvokeAIColorLogFormatter,
+    "syslog": InvokeAISyslogFormatter,
+    "legacy": InvokeAILegacyLogFormatter,
 }
 
 
 class InvokeAILogger(object):
     loggers = dict()
 
     @classmethod
-    def getLogger(cls,
-                  name: str = 'InvokeAI',
-                  config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
+    def getLogger(
+        cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
+    ) -> logging.Logger:
         if name in cls.loggers:
             logger = cls.loggers[name]
             logger.handlers.clear()
         else:
             logger = logging.getLogger(name)
-        logger.setLevel(config.log_level.upper()) # yes, strings work here
+        logger.setLevel(config.log_level.upper())  # yes, strings work here
         for ch in cls.getLoggers(config):
             logger.addHandler(ch)
         cls.loggers[name] = logger
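For reference, a minimal usage sketch of the class above (the import path is assumed, not confirmed by this diff):

from invokeai.backend.util.logging import InvokeAILogger  # module path assumed

logger = InvokeAILogger.getLogger("InvokeAI")
logger.info("model loaded")  # emitted through the handlers built from config.log_handlers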
@@ -344,82 +371,80 @@ class InvokeAILogger(object):
         handler_strs = config.log_handlers
         handlers = list()
         for handler in handler_strs:
-            handler_name,*args = handler.split('=',2)
+            handler_name, *args = handler.split("=", 2)
             args = args[0] if len(args) > 0 else None
 
             # console and file get the fancy formatter.
             # syslog gets a simple one
             # http gets no custom formatter
             formatter = LOG_FORMATTERS[config.log_format]
-            if handler_name=='console':
+            if handler_name == "console":
                 ch = logging.StreamHandler()
                 ch.setFormatter(formatter())
                 handlers.append(ch)
 
-            elif handler_name=='syslog':
+            elif handler_name == "syslog":
                 ch = cls._parse_syslog_args(args)
                 handlers.append(ch)
 
-            elif handler_name=='file':
+            elif handler_name == "file":
                 ch = cls._parse_file_args(args)
                 ch.setFormatter(formatter())
                 handlers.append(ch)
 
-            elif handler_name=='http':
+            elif handler_name == "http":
                 ch = cls._parse_http_args(args)
                 handlers.append(ch)
         return handlers
 
     @staticmethod
-    def _parse_syslog_args(
-        args: str=None
-    )-> logging.Handler:
+    def _parse_syslog_args(args: str = None) -> logging.Handler:
         if not SYSLOG_AVAILABLE:
             raise ValueError("syslog is not available on this system")
         if not args:
-            args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
+            args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
         syslog_args = dict()
         try:
-            for a in args.split(','):
-                arg_name,*arg_value = a.split(':',2)
-                if arg_name=='address':
-                    host,*port = arg_value
-                    port = 514 if len(port)==0 else int(port[0])
-                    syslog_args['address'] = (host,port)
-                elif arg_name=='facility':
-                    syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
-                elif arg_name=='socktype':
-                    syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
+            for a in args.split(","):
+                arg_name, *arg_value = a.split(":", 2)
+                if arg_name == "address":
+                    host, *port = arg_value
+                    port = 514 if len(port) == 0 else int(port[0])
+                    syslog_args["address"] = (host, port)
+                elif arg_name == "facility":
+                    syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
+                elif arg_name == "socktype":
+                    syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
                 else:
-                    syslog_args['address'] = arg_name
+                    syslog_args["address"] = arg_name
         except:
             raise ValueError(f"{args} is not a value argument list for syslog logging")
         return logging.handlers.SysLogHandler(**syslog_args)
 
     @staticmethod
-    def _parse_file_args(args: str=None)-> logging.Handler:
+    def _parse_file_args(args: str = None) -> logging.Handler:
         if not args:
             raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
         return logging.FileHandler(args)
 
     @staticmethod
-    def _parse_http_args(args: str=None)-> logging.Handler:
+    def _parse_http_args(args: str = None) -> logging.Handler:
         if not args:
             raise ValueError("please provide destination for http logging using format 'http=url'")
-        arg_list = args.split(',')
+        arg_list = args.split(",")
         url = urllib.parse.urlparse(arg_list.pop(0))
-        if url.scheme != 'http':
+        if url.scheme != "http":
             raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
         host = url.hostname
         path = url.path
         port = url.port or 80
 
         syslog_args = dict()
         for a in arg_list:
-            arg_name, *arg_value = a.split(':',2)
-            if arg_name=='method':
-                arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
+            arg_name, *arg_value = a.split(":", 2)
+            if arg_name == "method":
+                arg_value = arg_value[0] if len(arg_value) > 0 else "GET"
                 syslog_args[arg_name] = arg_value
             else:  # TODO: Provide support for SSL context and credentials
                 pass
-        return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
+        return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)
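For reference, the handler strings parsed above follow a "name=args" syntax. A small self-contained sketch of how such specs decompose (the values are illustrative, not InvokeAI defaults):

# Illustrative handler specs in the "name=args" syntax parsed by getLoggers:
handler_strs = [
    "console",
    "file=/var/log/invokeai.log",
    "syslog=address:localhost:514,facility:LOG_LOCAL0",
    "http=http://localhost:8080/ingest,method:POST",
]
for handler in handler_strs:
    handler_name, *args = handler.split("=", 2)
    print(handler_name, args[0] if args else None)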
@@ -8,6 +8,8 @@ if torch.backends.mps.is_available():
 
+
 _torch_layer_norm = torch.nn.functional.layer_norm
 
+
 def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
     if input.device.type == "mps" and input.dtype == torch.float16:
         input = input.float()
@@ -19,20 +21,26 @@ def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
     else:
         return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
 
+
 torch.nn.functional.layer_norm = new_layer_norm
 
+
 _torch_tensor_permute = torch.Tensor.permute
 
+
 def new_torch_tensor_permute(input, *dims):
     result = _torch_tensor_permute(input, *dims)
     if input.device == "mps" and input.dtype == torch.float16:
         result = result.contiguous()
     return result
 
+
 torch.Tensor.permute = new_torch_tensor_permute
 
+
 _torch_lerp = torch.lerp
 
+
 def new_torch_lerp(input, end, weight, *, out=None):
     if input.device.type == "mps" and input.dtype == torch.float16:
         input = input.float()
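All of the patches in this file follow one wrap-and-upcast pattern: save the original torch function, upcast float16 MPS tensors to float32, call through, and cast back. A minimal sketch of that pattern on a hypothetical target (torch.sin is just an example; InvokeAI patches layer_norm, permute, lerp, and interpolate):

import torch

_orig_sin = torch.sin  # keep a reference to the unpatched function

def _patched_sin(input):
    # Some MPS kernels misbehave in float16: compute in float32, cast back.
    if input.device.type == "mps" and input.dtype == torch.float16:
        return _orig_sin(input.float()).half()
    return _orig_sin(input)

torch.sin = _patched_sin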
@@ -52,20 +60,36 @@ def new_torch_lerp(input, end, weight, *, out=None):
     else:
         return _torch_lerp(input, end, weight, out=out)
 
+
 torch.lerp = new_torch_lerp
 
+
 _torch_interpolate = torch.nn.functional.interpolate
-def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
+
+
+def new_torch_interpolate(
+    input,
+    size=None,
+    scale_factor=None,
+    mode="nearest",
+    align_corners=None,
+    recompute_scale_factor=None,
+    antialias=False,
+):
     if input.device.type == "mps" and input.dtype == torch.float16:
-        return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
+        return _torch_interpolate(
+            input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
+        ).half()
     else:
         return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
 
+
 torch.nn.functional.interpolate = new_torch_interpolate
 
 # TODO: refactor it
 _SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
 
 
 class ChunkedSlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
@@ -78,7 +102,7 @@ class ChunkedSlicedAttnProcessor:
 
     def __init__(self, slice_size):
         assert isinstance(slice_size, int)
-        slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
+        slice_size = 1  # TODO: maybe implement chunking in batches too when enough memory
         self.slice_size = slice_size
         self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
@@ -121,7 +145,9 @@ class ChunkedSlicedAttnProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        chunk_tmp_tensor = torch.empty(self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
+        chunk_tmp_tensor = torch.empty(
+            self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+        )
 
         for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
@@ -131,7 +157,15 @@ class ChunkedSlicedAttnProcessor:
             key_slice = key[start_idx:end_idx]
             attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            self.get_attention_scores_chunked(attn, query_slice, key_slice, attn_mask_slice, hidden_states[start_idx:end_idx], value[start_idx:end_idx], chunk_tmp_tensor)
+            self.get_attention_scores_chunked(
+                attn,
+                query_slice,
+                key_slice,
+                attn_mask_slice,
+                hidden_states[start_idx:end_idx],
+                value[start_idx:end_idx],
+                chunk_tmp_tensor,
+            )
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -150,7 +184,6 @@ class ChunkedSlicedAttnProcessor:
 
         return hidden_states
 
-
     def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
         # batch size = 1
         assert query.shape[0] == 1
@@ -163,14 +196,14 @@ class ChunkedSlicedAttnProcessor:
             query = query.float()
             key = key.float()
 
-        #out_item_size = query.dtype.itemsize
-        #if attn.upcast_attention:
+        # out_item_size = query.dtype.itemsize
+        # if attn.upcast_attention:
         #     out_item_size = torch.float32.itemsize
         out_item_size = query.element_size()
         if attn.upcast_attention:
             out_item_size = 4
 
-        chunk_size = 2 ** 29
+        chunk_size = 2**29
 
         out_size = query.shape[1] * key.shape[1] * out_item_size
         chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
@@ -181,8 +214,8 @@ class ChunkedSlicedAttnProcessor:
         def _get_chunk_view(tensor, start, length):
             if start + length > tensor.shape[1]:
                 length = tensor.shape[1] - start
-            #print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
-            return tensor[:,start:start+length]
+            # print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
+            return tensor[:, start : start + length]
 
         for chunk_pos in range(0, query.shape[1], chunk_step):
             if attention_mask is not None:
@@ -196,7 +229,7 @@ class ChunkedSlicedAttnProcessor:
                 )
             else:
                 torch.baddbmm(
-                    torch.zeros((1,1,1), device=query.device, dtype=query.dtype),
+                    torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
                     _get_chunk_view(query, chunk_pos, chunk_step),
                     key,
                     beta=0,
@@ -206,7 +239,7 @@ class ChunkedSlicedAttnProcessor:
             chunk = chunk.softmax(dim=-1)
             torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
 
-            #del chunk
+            # del chunk
 
 
 diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
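A quick sanity check of the chunk budgeting above: chunk_size is a 2**29-byte (512 MiB) ceiling on the attention-score buffer, and chunks_count is how many row-chunks of the query are needed to stay under it. With illustrative numbers:

import math

query_tokens, key_tokens, item_size = 4096, 4096, 4  # illustrative shapes, float32 scores
chunk_size = 2**29  # 512 MiB budget per attention-score buffer

out_size = query_tokens * key_tokens * item_size  # 64 MiB for these shapes
chunks_count = min(query_tokens, math.ceil((out_size - 1) / chunk_size))
print(chunks_count)  # 1 -> the whole score matrix fits in a single chunk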
@@ -32,9 +32,7 @@ def log_txt_as_img(wh, xc, size=10):
         draw = ImageDraw.Draw(txt)
         font = ImageFont.load_default()
         nc = int(40 * (wh[0] / 256))
-        lines = "\n".join(
-            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
-        )
+        lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
 
         try:
             draw.text((0, 0), lines, fill="black", font=font)
@@ -81,9 +79,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.debug(
-            f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
-        )
+        logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
@@ -154,21 +150,12 @@ def parallel_data_prefetch(
         proc = Thread
     # spawn processes
     if target_data_type == "ndarray":
-        arguments = [
-            [func, Q, part, i, use_worker_id]
-            for i, part in enumerate(np.array_split(data, n_proc))
-        ]
+        arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
     else:
-        step = (
-            int(len(data) / n_proc + 1)
-            if len(data) % n_proc != 0
-            else int(len(data) / n_proc)
-        )
+        step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
         arguments = [
             [func, Q, part, i, use_worker_id]
-            for i, part in enumerate(
-                [data[i : i + step] for i in range(0, len(data), step)]
-            )
+            for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
         ]
     processes = []
     for i in range(n_proc):
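A quick check of the step arithmetic above, which rounds the per-worker slice up when the data does not divide evenly (numbers are illustrative):

data = list(range(10))
n_proc = 3
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
parts = [data[i : i + step] for i in range(0, len(data), step)]
print(step, [len(p) for p in parts])  # 4 [4, 4, 2]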
@@ -220,9 +207,7 @@ def parallel_data_prefetch(
     return gather_res
 
 
-def rand_perlin_2d(
-    shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
-):
+def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])
@@ -265,9 +250,9 @@ def rand_perlin_2d(
     n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
     t = fade(grid[: shape[0], : shape[1]])
-    noise = math.sqrt(2) * torch.lerp(
-        torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
-    ).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
+        device
+    )
     return noise.to(dtype=torch_dtype(device))
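The fade lambda in rand_perlin_2d is the standard Perlin quintic smoothstep, chosen so the interpolation weight is flat (zero first and second derivatives) at both endpoints. A quick check:

fade = lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
print(fade(0.0), fade(1.0))  # 0.0 1.0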
@@ -276,9 +261,7 @@ def ask_user(question: str, answers: list):
 
     user_prompt = f"\n>> {question} {answers}: "
     invalid_answer_msg = "Invalid answer. Please try again."
-    pose_question = chain(
-        [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
-    )
+    pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
     user_answers = map(input, pose_question)
     valid_response = next(filter(answers.__contains__, user_answers))
     return valid_response
@@ -303,9 +286,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
 
     if dest.is_dir():
         try:
-            file_name = re.search(
-                'filename="(.+)"', resp.headers.get("Content-Disposition")
-            ).group(1)
+            file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
         except:
             file_name = os.path.basename(url)
         dest = dest / file_name
@@ -322,7 +303,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
             logger.warning("corrupt existing file found. re-downloading")
             os.remove(dest)
             exist_size = 0
 
     if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
         logger.warning(f"{dest}: complete file found. Skipping.")
         return dest
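For context, the resume logic above relies on HTTP range requests: the server answers 206 with the remaining bytes, or 416 when the requested range starts at or past the end of the file (which the code treats as "already complete"). A minimal sketch of that request pattern, with an assumed placeholder URL:

import requests

url = "https://example.com/model.safetensors"  # placeholder URL
exist_size = 1024  # bytes already on disk
resp = requests.get(url, headers={"Range": f"bytes={exist_size}-"}, stream=True)
print(resp.status_code)  # 206 -> append the body; 416 -> nothing left to fetch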
@@ -377,16 +358,16 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
     buffered = io.BytesIO()
     image.save(buffered, format=image_format)
     mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
-    image_base64 = f"data:{mime_type};base64," + base64.b64encode(
-        buffered.getvalue()
-    ).decode("UTF-8")
+    image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")
     return image_base64
 
 
 class Chdir(object):
-    '''Context manager to chdir to desired directory and change back after context exits:
+    """Context manager to chdir to desired directory and change back after context exits:
     Args:
         path (Path): The path to the cwd
-    '''
+    """
 
     def __init__(self, path: Path):
         self.path = path
         self.original = Path().absolute()
@@ -394,5 +375,5 @@ class Chdir(object):
     def __enter__(self):
         os.chdir(self.path)
 
-    def __exit__(self,*args):
+    def __exit__(self, *args):
         os.chdir(self.original)
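A minimal usage sketch of the Chdir context manager above (assuming it is importable from invokeai.backend.util; the target directory is illustrative):

from pathlib import Path
from invokeai.backend.util import Chdir  # import path assumed

with Chdir(Path("/tmp")):
    ...  # current working directory is /tmp inside this block
# the previous working directory is restored on exit, even on exceptions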