mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
convert remainder of print() to log.info()
This commit is contained in:
@ -96,6 +96,7 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.image_util import retrieve_metadata
|
from invokeai.backend.image_util import retrieve_metadata
|
||||||
|
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
@ -189,7 +190,7 @@ class Args(object):
|
|||||||
print(f"{APP_NAME} {APP_VERSION}")
|
print(f"{APP_NAME} {APP_VERSION}")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
print("* Initializing, be patient...")
|
log.info("Initializing, be patient...")
|
||||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||||
Globals.try_patchmatch = switches.patchmatch
|
Globals.try_patchmatch = switches.patchmatch
|
||||||
|
|
||||||
@ -197,14 +198,13 @@ class Args(object):
|
|||||||
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
||||||
legacyinit = os.path.expanduser("~/.invokeai")
|
legacyinit = os.path.expanduser("~/.invokeai")
|
||||||
if os.path.exists(initfile):
|
if os.path.exists(initfile):
|
||||||
print(
|
log.info(
|
||||||
f">> Initialization file {initfile} found. Loading...",
|
f"Initialization file {initfile} found. Loading...",
|
||||||
file=sys.stderr,
|
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{initfile}")
|
sysargs.insert(0, f"@{initfile}")
|
||||||
elif os.path.exists(legacyinit):
|
elif os.path.exists(legacyinit):
|
||||||
print(
|
log.warning(
|
||||||
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{legacyinit}")
|
sysargs.insert(0, f"@{legacyinit}")
|
||||||
Globals.log_tokenization = self._arg_parser.parse_args(
|
Globals.log_tokenization = self._arg_parser.parse_args(
|
||||||
@ -214,7 +214,7 @@ class Args(object):
|
|||||||
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
||||||
return self._arg_switches
|
return self._arg_switches
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An exception has occurred: {e}")
|
log.error(f"An exception has occurred: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_cmd(self, cmd_string):
|
def parse_cmd(self, cmd_string):
|
||||||
@ -1154,7 +1154,7 @@ class Args(object):
|
|||||||
|
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
print("format_metadata() is deprecated. Please use metadata_dumps()")
|
log.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||||
return metadata_dumps(kwargs)
|
return metadata_dumps(kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
print(">> could not read metadata", file=sys.stderr)
|
log.error("Could not read metadata")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from .args import metadata_from_png
|
from .args import metadata_from_png
|
||||||
from .generator import infill_methods
|
from .generator import infill_methods
|
||||||
from .globals import Globals, global_cache_dir
|
from .globals import Globals, global_cache_dir
|
||||||
@ -195,12 +196,12 @@ class Generate:
|
|||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
# it wasn't actually doing anything. This logic could be reinstated.
|
# it wasn't actually doing anything. This logic could be reinstated.
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
print(f">> Using device_type {self.device.type}")
|
log.info(f"Using device_type {self.device.type}")
|
||||||
if full_precision:
|
if full_precision:
|
||||||
if self.precision != "auto":
|
if self.precision != "auto":
|
||||||
raise ValueError("Remove --full_precision / -F if using --precision")
|
raise ValueError("Remove --full_precision / -F if using --precision")
|
||||||
print("Please remove deprecated --full_precision / -F")
|
log.warning("Please remove deprecated --full_precision / -F")
|
||||||
print("If auto config does not work you can use --precision=float32")
|
log.warning("If auto config does not work you can use --precision=float32")
|
||||||
self.precision = "float32"
|
self.precision = "float32"
|
||||||
if self.precision == "auto":
|
if self.precision == "auto":
|
||||||
self.precision = choose_precision(self.device)
|
self.precision = choose_precision(self.device)
|
||||||
@ -208,13 +209,13 @@ class Generate:
|
|||||||
|
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
if torch.cuda.is_available() and not Globals.disable_xformers:
|
if torch.cuda.is_available() and not Globals.disable_xformers:
|
||||||
print(">> xformers memory-efficient attention is available and enabled")
|
log.info("xformers memory-efficient attention is available and enabled")
|
||||||
else:
|
else:
|
||||||
print(
|
log.info(
|
||||||
">> xformers memory-efficient attention is available but disabled"
|
"xformers memory-efficient attention is available but disabled"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(">> xformers not installed")
|
log.info("xformers not installed")
|
||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_manager = ModelManager(
|
self.model_manager = ModelManager(
|
||||||
@ -229,8 +230,8 @@ class Generate:
|
|||||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
model = model or fallback
|
model = model or fallback
|
||||||
if not self.model_manager.valid_model(model):
|
if not self.model_manager.valid_model(model):
|
||||||
print(
|
log.warning(
|
||||||
f'** "{model}" is not a known model name; falling back to {fallback}.'
|
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||||
)
|
)
|
||||||
model = None
|
model = None
|
||||||
self.model_name = model or fallback
|
self.model_name = model or fallback
|
||||||
@ -246,10 +247,10 @@ class Generate:
|
|||||||
|
|
||||||
# load safety checker if requested
|
# load safety checker if requested
|
||||||
if safety_checker:
|
if safety_checker:
|
||||||
print(">> Initializing NSFW checker")
|
log.info("Initializing NSFW checker")
|
||||||
self.safety_checker = SafetyChecker(self.device)
|
self.safety_checker = SafetyChecker(self.device)
|
||||||
else:
|
else:
|
||||||
print(">> NSFW checker is disabled")
|
log.info("NSFW checker is disabled")
|
||||||
|
|
||||||
def prompt2png(self, prompt, outdir, **kwargs):
|
def prompt2png(self, prompt, outdir, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -567,7 +568,7 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
if catch_interrupts:
|
if catch_interrupts:
|
||||||
print("**Interrupted** Partial results will be returned.")
|
log.warning("Interrupted** Partial results will be returned.")
|
||||||
else:
|
else:
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@ -575,11 +576,11 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
print(">> Could not generate image.")
|
log.info("Could not generate image.")
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print("\n>> Usage stats:")
|
log.info("Usage stats:")
|
||||||
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
|
log.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
||||||
self.print_cuda_stats()
|
self.print_cuda_stats()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -609,16 +610,16 @@ class Generate:
|
|||||||
def print_cuda_stats(self):
|
def print_cuda_stats(self):
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
self.gather_cuda_stats()
|
self.gather_cuda_stats()
|
||||||
print(
|
log.info(
|
||||||
">> Max VRAM used for this generation:",
|
"Max VRAM used for this generation: "+
|
||||||
"%4.2fG." % (self.max_memory_allocated / 1e9),
|
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
||||||
"Current VRAM utilization:",
|
"Current VRAM utilization: "+
|
||||||
"%4.2fG" % (self.memory_allocated / 1e9),
|
"%4.2fG" % (self.memory_allocated / 1e9)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
log.info(
|
||||||
">> Max VRAM used since script start: ",
|
"Max VRAM used since script start: " +
|
||||||
"%4.2fG" % (self.session_peakmem / 1e9),
|
"%4.2fG" % (self.session_peakmem / 1e9)
|
||||||
)
|
)
|
||||||
|
|
||||||
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
||||||
@ -647,7 +648,7 @@ class Generate:
|
|||||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
|
|
||||||
prompt = opt.prompt or args.prompt or ""
|
prompt = opt.prompt or args.prompt or ""
|
||||||
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
|
log.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||||
|
|
||||||
# try to reuse the same filename prefix as the original file.
|
# try to reuse the same filename prefix as the original file.
|
||||||
# we take everything up to the first period
|
# we take everything up to the first period
|
||||||
@ -696,8 +697,8 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
extend_instructions[direction] = int(pixels)
|
extend_instructions[direction] = int(pixels)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(
|
log.warning(
|
||||||
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||||
)
|
)
|
||||||
|
|
||||||
opt.seed = seed
|
opt.seed = seed
|
||||||
@ -720,8 +721,8 @@ class Generate:
|
|||||||
# fetch the metadata from the image
|
# fetch the metadata from the image
|
||||||
generator = self.select_generator(embiggen=True)
|
generator = self.select_generator(embiggen=True)
|
||||||
opt.strength = opt.embiggen_strength or 0.40
|
opt.strength = opt.embiggen_strength or 0.40
|
||||||
print(
|
log.info(
|
||||||
f">> Setting img2img strength to {opt.strength} for happy embiggening"
|
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
||||||
)
|
)
|
||||||
generator.generate(
|
generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
@ -748,12 +749,12 @@ class Generate:
|
|||||||
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
print(
|
log.warning(
|
||||||
"* please provide at least one postprocessing option, such as -G or -U"
|
"please provide at least one postprocessing option, such as -G or -U"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
print(f"* postprocessing tool {tool} is not yet supported")
|
log.warning(f"postprocessing tool {tool} is not yet supported")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def select_generator(
|
def select_generator(
|
||||||
@ -797,8 +798,8 @@ class Generate:
|
|||||||
image = self._load_img(img)
|
image = self._load_img(img)
|
||||||
|
|
||||||
if image.width < self.width and image.height < self.height:
|
if image.width < self.width and image.height < self.height:
|
||||||
print(
|
log.warning(
|
||||||
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||||
)
|
)
|
||||||
|
|
||||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||||
@ -809,8 +810,8 @@ class Generate:
|
|||||||
if (image.width * image.height) > (
|
if (image.width * image.height) > (
|
||||||
self.width * self.height
|
self.width * self.height
|
||||||
) and self.size_matters:
|
) and self.size_matters:
|
||||||
print(
|
log.info(
|
||||||
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||||
)
|
)
|
||||||
self.size_matters = False
|
self.size_matters = False
|
||||||
|
|
||||||
@ -891,11 +892,11 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
model_data = cache.get_model(model_name)
|
model_data = cache.get_model(model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"** model {model_name} could not be loaded: {str(e)}")
|
log.warning(f"model {model_name} could not be loaded: {str(e)}")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
if previous_model_name is None:
|
if previous_model_name is None:
|
||||||
raise e
|
raise e
|
||||||
print("** trying to reload previous model")
|
log.warning("trying to reload previous model")
|
||||||
model_data = cache.get_model(previous_model_name) # load previous
|
model_data = cache.get_model(previous_model_name) # load previous
|
||||||
if model_data is None:
|
if model_data is None:
|
||||||
raise e
|
raise e
|
||||||
@ -962,15 +963,15 @@ class Generate:
|
|||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == "gfpgan":
|
if facetool == "gfpgan":
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
print(
|
log.info(
|
||||||
">> GFPGAN not found. Face restoration is disabled."
|
"GFPGAN not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
if facetool == "codeformer":
|
if facetool == "codeformer":
|
||||||
if self.codeformer is None:
|
if self.codeformer is None:
|
||||||
print(
|
log.info(
|
||||||
">> CodeFormer not found. Face restoration is disabled."
|
"CodeFormer not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cf_device = (
|
cf_device = (
|
||||||
@ -984,7 +985,7 @@ class Generate:
|
|||||||
fidelity=codeformer_fidelity,
|
fidelity=codeformer_fidelity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(">> Face Restoration is disabled.")
|
log.info("Face Restoration is disabled.")
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if self.esrgan is not None:
|
if self.esrgan is not None:
|
||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
@ -997,10 +998,10 @@ class Generate:
|
|||||||
denoise_str=upscale_denoise_str,
|
denoise_str=upscale_denoise_str,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
log.info("ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
log.info(
|
||||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
@ -1066,17 +1067,17 @@ class Generate:
|
|||||||
if self.sampler_name in scheduler_map:
|
if self.sampler_name in scheduler_map:
|
||||||
sampler_class = scheduler_map[self.sampler_name]
|
sampler_class = scheduler_map[self.sampler_name]
|
||||||
msg = (
|
msg = (
|
||||||
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||||
)
|
)
|
||||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f">> Unsupported Sampler: {self.sampler_name} "
|
f" Unsupported Sampler: {self.sampler_name} "+
|
||||||
f"Defaulting to {default}"
|
f"Defaulting to {default}"
|
||||||
)
|
)
|
||||||
self.sampler = default
|
self.sampler = default
|
||||||
|
|
||||||
print(msg)
|
log.info(msg)
|
||||||
|
|
||||||
if not hasattr(self.sampler, "uses_inpainting_model"):
|
if not hasattr(self.sampler, "uses_inpainting_model"):
|
||||||
# FIXME: terrible kludge!
|
# FIXME: terrible kludge!
|
||||||
@ -1085,17 +1086,17 @@ class Generate:
|
|||||||
def _load_img(self, img) -> Image:
|
def _load_img(self, img) -> Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
print(f">> using provided input image of size {image.width}x{image.height}")
|
log.info(f"using provided input image of size {image.width}x{image.height}")
|
||||||
elif isinstance(img, str):
|
elif isinstance(img, str):
|
||||||
assert os.path.exists(img), f">> {img}: File not found"
|
assert os.path.exists(img), f">> {img}: File not found"
|
||||||
|
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
print(
|
log.info(
|
||||||
f">> loaded input image of size {image.width}x{image.height} from {img}"
|
f"loaded input image of size {image.width}x{image.height} from {img}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
print(f">> loaded input image of size {image.width}x{image.height}")
|
log.info(f"loaded input image of size {image.width}x{image.height}")
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@ -1183,14 +1184,14 @@ class Generate:
|
|||||||
|
|
||||||
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
||||||
if not mask:
|
if not mask:
|
||||||
print(
|
log.info(
|
||||||
">> Initial image has transparent areas. Will inpaint in these regions."
|
"Initial image has transparent areas. Will inpaint in these regions."
|
||||||
)
|
)
|
||||||
if (not force_outpaint) and self._check_for_erasure(image):
|
if (not force_outpaint) and self._check_for_erasure(image):
|
||||||
print(
|
log.info(
|
||||||
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
|
"Colors underneath the transparent region seem to have been erased.\n" +
|
||||||
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
|
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
||||||
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
|
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _squeeze_image(self, image):
|
def _squeeze_image(self, image):
|
||||||
@ -1201,11 +1202,11 @@ class Generate:
|
|||||||
|
|
||||||
def _fit_image(self, image, max_dimensions):
|
def _fit_image(self, image, max_dimensions):
|
||||||
w, h = max_dimensions
|
w, h = max_dimensions
|
||||||
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
|
log.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
||||||
# note that InitImageResizer does the multiple of 64 truncation internally
|
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||||
image = InitImageResizer(image).resize(width=w, height=h)
|
image = InitImageResizer(image).resize(width=w, height=h)
|
||||||
print(
|
log.info(
|
||||||
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||||
)
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@ -1216,8 +1217,8 @@ class Generate:
|
|||||||
) # resize to integer multiple of 64
|
) # resize to integer multiple of 64
|
||||||
if h != height or w != width:
|
if h != height or w != width:
|
||||||
if log:
|
if log:
|
||||||
print(
|
log.info(
|
||||||
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||||
)
|
)
|
||||||
height = h
|
height = h
|
||||||
width = w
|
width = w
|
||||||
|
@ -25,6 +25,7 @@ from typing import Callable, List, Iterator, Optional, Type
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from ..image_util import configure_model_padding
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
@ -372,7 +373,7 @@ class Generator:
|
|||||||
try:
|
try:
|
||||||
x_T = self.get_noise(width, height)
|
x_T = self.get_noise(width, height)
|
||||||
except:
|
except:
|
||||||
print("** An error occurred while getting initial noise **")
|
log.error("An error occurred while getting initial noise")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||||
@ -607,7 +608,7 @@ class Generator:
|
|||||||
image = self.sample_to_image(sample)
|
image = self.sample_to_image(sample)
|
||||||
dirname = os.path.dirname(filepath) or "."
|
dirname = os.path.dirname(filepath) or "."
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
print(f"** creating directory {dirname}")
|
log.info(f"creating directory {dirname}")
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
image.save(filepath, "PNG")
|
image.save(filepath, "PNG")
|
||||||
|
|
||||||
|
@ -8,10 +8,11 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
|
|
||||||
from .base import Generator
|
from .base import Generator
|
||||||
from .img2img import Img2Img
|
from .img2img import Img2Img
|
||||||
|
|
||||||
|
|
||||||
class Embiggen(Generator):
|
class Embiggen(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
@ -72,22 +73,22 @@ class Embiggen(Generator):
|
|||||||
embiggen = [1.0] # If not specified, assume no scaling
|
embiggen = [1.0] # If not specified, assume no scaling
|
||||||
elif embiggen[0] < 0:
|
elif embiggen[0] < 0:
|
||||||
embiggen[0] = 1.0
|
embiggen[0] = 1.0
|
||||||
print(
|
log.warning(
|
||||||
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 2:
|
if len(embiggen) < 2:
|
||||||
embiggen.append(0.75)
|
embiggen.append(0.75)
|
||||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||||
embiggen[1] = 0.75
|
embiggen[1] = 0.75
|
||||||
print(
|
log.warning(
|
||||||
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 3:
|
if len(embiggen) < 3:
|
||||||
embiggen.append(0.25)
|
embiggen.append(0.25)
|
||||||
elif embiggen[2] < 0:
|
elif embiggen[2] < 0:
|
||||||
embiggen[2] = 0.25
|
embiggen[2] = 0.25
|
||||||
print(
|
log.warning(
|
||||||
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||||
@ -97,8 +98,8 @@ class Embiggen(Generator):
|
|||||||
embiggen_tiles.sort()
|
embiggen_tiles.sort()
|
||||||
|
|
||||||
if strength >= 0.5:
|
if strength >= 0.5:
|
||||||
print(
|
log.warning(
|
||||||
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prep img2img generator, since we wrap over it
|
# Prep img2img generator, since we wrap over it
|
||||||
@ -121,8 +122,8 @@ class Embiggen(Generator):
|
|||||||
from ..restoration.realesrgan import ESRGAN
|
from ..restoration.realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN()
|
esrgan = ESRGAN()
|
||||||
print(
|
log.info(
|
||||||
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||||
)
|
)
|
||||||
if embiggen[0] > 2:
|
if embiggen[0] > 2:
|
||||||
initsuperimage = esrgan.process(
|
initsuperimage = esrgan.process(
|
||||||
@ -312,10 +313,10 @@ class Embiggen(Generator):
|
|||||||
def make_image():
|
def make_image():
|
||||||
# Make main tiles -------------------------------------------------
|
# Make main tiles -------------------------------------------------
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
log.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||||
else:
|
else:
|
||||||
print(
|
log.info(
|
||||||
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||||
)
|
)
|
||||||
|
|
||||||
emb_tile_store = []
|
emb_tile_store = []
|
||||||
@ -361,11 +362,11 @@ class Embiggen(Generator):
|
|||||||
# newinitimage.save(newinitimagepath)
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
print(
|
log.debug(
|
||||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
log.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
# create a torch tensor from an Image
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||||
@ -547,8 +548,8 @@ class Embiggen(Generator):
|
|||||||
# Layer tile onto final image
|
# Layer tile onto final image
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
else:
|
else:
|
||||||
print(
|
log.error(
|
||||||
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||||
)
|
)
|
||||||
|
|
||||||
# after internal loops and patching up return Embiggen image
|
# after internal loops and patching up return Embiggen image
|
||||||
|
@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
|||||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
|
|
||||||
class Txt2Img2Img(Generator):
|
class Txt2Img2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
|
|||||||
# the message below is accurate.
|
# the message below is accurate.
|
||||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||||
print(
|
log.info(
|
||||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||||
)
|
)
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
|
@ -5,10 +5,9 @@ wraps the actual patchmatch object. It respects the global
|
|||||||
be suppressed or deferred
|
be suppressed or deferred
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
Thin class wrapper around the patchmatch function.
|
Thin class wrapper around the patchmatch function.
|
||||||
@ -28,12 +27,12 @@ class PatchMatch:
|
|||||||
from patchmatch import patch_match as pm
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
if pm.patchmatch_available:
|
||||||
print(">> Patchmatch initialized")
|
log.info("Patchmatch initialized")
|
||||||
else:
|
else:
|
||||||
print(">> Patchmatch not loaded (nonfatal)")
|
log.info("Patchmatch not loaded (nonfatal)")
|
||||||
self.patch_match = pm
|
self.patch_match = pm
|
||||||
else:
|
else:
|
||||||
print(">> Patchmatch loading disabled")
|
log.info("Patchmatch loading disabled")
|
||||||
self.tried_load = True
|
self.tried_load = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -30,9 +30,9 @@ work fine.
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from torchvision import transforms
|
|
||||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import global_cache_dir
|
from invokeai.backend.globals import global_cache_dir
|
||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
@ -83,7 +83,7 @@ class Txt2Mask(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
print(">> Initializing clipseg model for text to mask inference")
|
log.info("Initializing clipseg model for text to mask inference")
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -101,18 +101,6 @@ class Txt2Mask(object):
|
|||||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||||
pixels indicate where the object is inferred to be.
|
pixels indicate where the object is inferred to be.
|
||||||
"""
|
"""
|
||||||
transform = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(
|
|
||||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
||||||
),
|
|
||||||
transforms.Resize(
|
|
||||||
(CLIPSEG_SIZE, CLIPSEG_SIZE)
|
|
||||||
), # must be multiple of 64...
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if type(image) is str:
|
if type(image) is str:
|
||||||
image = Image.open(image).convert("RGB")
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from typing import Union
|
|||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||||
|
|
||||||
from .model_manager import ModelManager, SDLegacyType
|
from .model_manager import ModelManager, SDLegacyType
|
||||||
@ -372,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
unet_key = "model.diffusion_model."
|
unet_key = "model.diffusion_model."
|
||||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
log.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||||
if extract_ema:
|
if extract_ema:
|
||||||
print(" | Extracting EMA weights (usually better for inference)")
|
log.debug("Extracting EMA weights (usually better for inference)")
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
@ -392,8 +393,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
key
|
key
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
log.debug(
|
||||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||||
)
|
)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
if "global_step" in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
global_step = checkpoint["global_step"]
|
global_step = checkpoint["global_step"]
|
||||||
else:
|
else:
|
||||||
print(" | global_step key not found in model")
|
log.debug("global_step key not found in model")
|
||||||
global_step = None
|
global_step = None
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
@ -1229,15 +1230,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
# If a replacement VAE path was specified, we'll incorporate that into
|
# If a replacement VAE path was specified, we'll incorporate that into
|
||||||
# the checkpoint model and then convert it
|
# the checkpoint model and then convert it
|
||||||
if vae_path:
|
if vae_path:
|
||||||
print(f" | Converting VAE {vae_path}")
|
log.debug(f"Converting VAE {vae_path}")
|
||||||
replace_checkpoint_vae(checkpoint,vae_path)
|
replace_checkpoint_vae(checkpoint,vae_path)
|
||||||
# otherwise we use the original VAE, provided that
|
# otherwise we use the original VAE, provided that
|
||||||
# an externally loaded diffusers VAE was not passed
|
# an externally loaded diffusers VAE was not passed
|
||||||
elif not vae:
|
elif not vae:
|
||||||
print(" | Using checkpoint model's original VAE")
|
log.debug("Using checkpoint model's original VAE")
|
||||||
|
|
||||||
if vae:
|
if vae:
|
||||||
print(" | Using replacement diffusers VAE")
|
log.debug("Using replacement diffusers VAE")
|
||||||
else: # convert the original or replacement VAE
|
else: # convert the original or replacement VAE
|
||||||
vae_config = create_vae_diffusers_config(
|
vae_config = create_vae_diffusers_config(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
|
@ -18,6 +18,7 @@ from compel.prompt_parser import (
|
|||||||
PromptParser,
|
PromptParser,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
@ -162,8 +163,8 @@ def log_tokenization(
|
|||||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
):
|
):
|
||||||
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
|
log.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||||
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
log.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||||
|
|
||||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||||
log_tokenization_for_prompt_object(
|
log_tokenization_for_prompt_object(
|
||||||
@ -237,12 +238,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
|||||||
usedTokens += 1
|
usedTokens += 1
|
||||||
|
|
||||||
if usedTokens > 0:
|
if usedTokens > 0:
|
||||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
log.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||||
print(f"{tokenized}\x1b[0m")
|
log.debug(f"{tokenized}\x1b[0m")
|
||||||
|
|
||||||
if discarded != "":
|
if discarded != "":
|
||||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
log.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||||
print(f"{discarded}\x1b[0m")
|
log.debug(f"{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
|
||||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||||
@ -295,8 +296,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
|||||||
return parsed_prompts
|
return parsed_prompts
|
||||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
if weight_sum == 0:
|
if weight_sum == 0:
|
||||||
print(
|
log.warning(
|
||||||
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
|
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||||
)
|
)
|
||||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import invokeai.backend.util.logging as log
|
||||||
|
|
||||||
class Restoration:
|
class Restoration:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
@ -8,17 +10,17 @@ class Restoration:
|
|||||||
# Load GFPGAN
|
# Load GFPGAN
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||||
if gfpgan.gfpgan_model_exists:
|
if gfpgan.gfpgan_model_exists:
|
||||||
print(">> GFPGAN Initialized")
|
log.info("GFPGAN Initialized")
|
||||||
else:
|
else:
|
||||||
print(">> GFPGAN Disabled")
|
log.info("GFPGAN Disabled")
|
||||||
gfpgan = None
|
gfpgan = None
|
||||||
|
|
||||||
# Load CodeFormer
|
# Load CodeFormer
|
||||||
codeformer = self.load_codeformer()
|
codeformer = self.load_codeformer()
|
||||||
if codeformer.codeformer_model_exists:
|
if codeformer.codeformer_model_exists:
|
||||||
print(">> CodeFormer Initialized")
|
log.info("CodeFormer Initialized")
|
||||||
else:
|
else:
|
||||||
print(">> CodeFormer Disabled")
|
log.info("CodeFormer Disabled")
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
return gfpgan, codeformer
|
return gfpgan, codeformer
|
||||||
@ -39,5 +41,5 @@ class Restoration:
|
|||||||
from .realesrgan import ESRGAN
|
from .realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
esrgan = ESRGAN(esrgan_bg_tile)
|
||||||
print(">> ESRGAN Initialized")
|
log.info("ESRGAN Initialized")
|
||||||
return esrgan
|
return esrgan
|
||||||
|
@ -5,6 +5,7 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from ..globals import Globals
|
from ..globals import Globals
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
@ -23,12 +24,12 @@ class CodeFormerRestoration:
|
|||||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
if not self.codeformer_model_exists:
|
||||||
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
|
log.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
sys.path.append(os.path.abspath(codeformer_dir))
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
|
log.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
@ -97,7 +98,7 @@ class CodeFormerRestoration:
|
|||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
print(f"\tFailed inference for CodeFormer: {error}.")
|
log.error(f"Failed inference for CodeFormer: {error}.")
|
||||||
restored_face = cropped_face
|
restored_face = cropped_face
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
restored_face = restored_face.astype("uint8")
|
||||||
|
@ -6,9 +6,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
@ -19,7 +19,7 @@ class GFPGAN:
|
|||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
if not self.gfpgan_model_exists:
|
||||||
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
|
log.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_exists(self):
|
def model_exists(self):
|
||||||
@ -27,7 +27,7 @@ class GFPGAN:
|
|||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
def process(self, image, strength: float, seed: str = None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
|
log.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
@ -47,14 +47,14 @@ class GFPGAN:
|
|||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
print(">> Error loading GFPGAN:", file=sys.stderr)
|
log.error("Error loading GFPGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
|
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
print(f">> WARNING: GFPGAN not initialized.")
|
log.warning("WARNING: GFPGAN not initialized.")
|
||||||
print(
|
log.warning(
|
||||||
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
|
|
||||||
class Outcrop(object):
|
class Outcrop(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -82,7 +82,7 @@ class Outcrop(object):
|
|||||||
pixels = extents[direction]
|
pixels = extents[direction]
|
||||||
# round pixels up to the nearest 64
|
# round pixels up to the nearest 64
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
pixels = math.ceil(pixels / 64) * 64
|
||||||
print(f">> extending image {direction}ward by {pixels} pixels")
|
log.info(f"extending image {direction}ward by {pixels} pixels")
|
||||||
image = self._rotate(image, direction)
|
image = self._rotate(image, direction)
|
||||||
image = self._extend(image, pixels)
|
image = self._extend(image, pixels)
|
||||||
image = self._rotate(image, direction, reverse=True)
|
image = self._rotate(image, direction, reverse=True)
|
||||||
|
@ -6,18 +6,13 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
self.bg_tile_size = bg_tile_size
|
self.bg_tile_size = bg_tile_size
|
||||||
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
use_half_precision = False
|
use_half_precision = False
|
||||||
@ -74,16 +69,16 @@ class ESRGAN:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
|
log.error("Error loading Real-ESRGAN:")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
if upsampler_scale == 0:
|
||||||
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
log.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
print(
|
log.info(
|
||||||
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||||
)
|
)
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
@ -14,6 +14,7 @@ from PIL import Image, ImageFilter
|
|||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from .globals import global_cache_dir
|
from .globals import global_cache_dir
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
@ -40,8 +41,8 @@ class SafetyChecker(object):
|
|||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(
|
log.error(
|
||||||
"** An error was encountered while installing the safety checker:"
|
"An error was encountered while installing the safety checker:"
|
||||||
)
|
)
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@ -65,8 +66,8 @@ class SafetyChecker(object):
|
|||||||
)
|
)
|
||||||
self.safety_checker.to(CPU_DEVICE) # offload
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
if has_nsfw_concept[0]:
|
if has_nsfw_concept[0]:
|
||||||
print(
|
log.warning(
|
||||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
||||||
)
|
)
|
||||||
return self.blur(image)
|
return self.blur(image)
|
||||||
else:
|
else:
|
||||||
|
@ -17,6 +17,7 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
@ -66,11 +67,11 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
log.warning(
|
||||||
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||||
)
|
)
|
||||||
print(
|
log.warning(
|
||||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||||
)
|
)
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
|
|
||||||
@ -81,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
"""
|
"""
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
print(
|
log.warning(
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -219,7 +220,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
if chunk == 0:
|
if chunk == 0:
|
||||||
bytes += total
|
bytes += total
|
||||||
|
|
||||||
print(f">> Downloading {repo_id}...", end="")
|
log.info(f"Downloading {repo_id}...", end="")
|
||||||
try:
|
try:
|
||||||
for file in (
|
for file in (
|
||||||
"README.md",
|
"README.md",
|
||||||
@ -233,22 +234,22 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
)
|
)
|
||||||
except ul_error.HTTPError as e:
|
except ul_error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
print(
|
log.warning(
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
log.warning(
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
except ul_error.URLError as e:
|
except ul_error.URLError as e:
|
||||||
print(
|
log.error(
|
||||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
print("...{:.2f}Kb".format(bytes / 1024))
|
log.info("...{:.2f}Kb".format(bytes / 1024))
|
||||||
return succeeded
|
return succeeded
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
def _concept_id(self, concept_name: str) -> str:
|
||||||
|
@ -14,9 +14,9 @@ from diffusers.models.cross_attention import AttnProcessor
|
|||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
TOKENS = 2
|
TOKENS = 2
|
||||||
@ -425,13 +425,13 @@ def get_cross_attention_modules(
|
|||||||
expected_count = 16
|
expected_count = 16
|
||||||
if cross_attention_modules_in_model_count != expected_count:
|
if cross_attention_modules_in_model_count != expected_count:
|
||||||
# non-fatal error but .swap() won't work.
|
# non-fatal error but .swap() won't work.
|
||||||
print(
|
log.error(
|
||||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||||
+ f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||||
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
||||||
+ f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||||
+ f"work properly until it is fixed."
|
+ "work properly until it is fixed."
|
||||||
)
|
)
|
||||||
return attention_module_tuples
|
return attention_module_tuples
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import torch
|
|||||||
from diffusers.models.cross_attention import AttnProcessor
|
from diffusers.models.cross_attention import AttnProcessor
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
@ -262,7 +263,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO remove when compvis codepath support is dropped
|
# TODO remove when compvis codepath support is dropped
|
||||||
if step_index is None and sigma is None:
|
if step_index is None and sigma is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Either step_index or sigma is required when doing cross attention control, but both are None."
|
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||||
)
|
)
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
return percent_through
|
return percent_through
|
||||||
@ -466,10 +467,14 @@ class InvokeAIDiffuserComponent:
|
|||||||
outside = torch.count_nonzero(
|
outside = torch.count_nonzero(
|
||||||
(latents < -current_threshold) | (latents > current_threshold)
|
(latents < -current_threshold) | (latents > current_threshold)
|
||||||
)
|
)
|
||||||
print(
|
log.info(
|
||||||
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
)
|
||||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
log.debug(
|
||||||
|
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||||
|
)
|
||||||
|
log.debug(
|
||||||
|
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||||
)
|
)
|
||||||
|
|
||||||
if maxval < current_threshold and minval > -current_threshold:
|
if maxval < current_threshold and minval > -current_threshold:
|
||||||
@ -496,9 +501,11 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_thresholding:
|
if self.debug_thresholding:
|
||||||
print(
|
log.debug(
|
||||||
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
)
|
||||||
|
log.debug(
|
||||||
|
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||||
)
|
)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
@ -599,7 +606,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# below is fugly omg
|
# below is fugly omg
|
||||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
|
||||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||||
chunk_count = ceil(len(conditionings) / 2)
|
chunk_count = ceil(len(conditionings) / 2)
|
||||||
|
@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
|||||||
|
|
||||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
|
||||||
|
|
||||||
@ -191,7 +191,7 @@ def mkdirs(paths):
|
|||||||
def mkdir_and_rename(path):
|
def mkdir_and_rename(path):
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
new_name = path + "_archived_" + get_timestamp()
|
new_name = path + "_archived_" + get_timestamp()
|
||||||
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
log.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||||
os.replace(path, new_name)
|
os.replace(path, new_name)
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||||
): # in case a token with literal angle brackets encountered
|
): # in case a token with literal angle brackets encountered
|
||||||
print(f">> Loaded local embedding for trigger {concept_name}")
|
log.info(f"Loaded local embedding for trigger {concept_name}")
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
print(f">> Loaded remote embedding for trigger {concept_name}")
|
log.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||||
|
|
||||||
@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||||
for embedding_info in embedding_list:
|
for embedding_info in embedding_list:
|
||||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||||
print(
|
log.warning(
|
||||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if ckpt_path.name == "learned_embeds.bin"
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
else f"<{ckpt_path.stem}>"
|
else f"<{ckpt_path.stem}>"
|
||||||
)
|
)
|
||||||
print(
|
log.info(
|
||||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||||
)
|
)
|
||||||
trigger_str = replacement_trigger_str
|
trigger_str = replacement_trigger_str
|
||||||
|
|
||||||
@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
log.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
print(f" | The error was {str(e)}")
|
log.debug(f"The error was {str(e)}")
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||||
@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
:return: The token id for the added embedding, either existing or newly-added.
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
"""
|
"""
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
print(
|
log.warning(
|
||||||
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if not self.full_precision:
|
if not self.full_precision:
|
||||||
@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if str(e).startswith("Warning"):
|
if str(e).startswith("Warning"):
|
||||||
print(f">> {str(e)}")
|
log.warning(f"{str(e)}")
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(
|
log.error(
|
||||||
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
for ti in self.textual_inversions:
|
for ti in self.textual_inversions:
|
||||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||||
if ti.embedding_vector_length > 1:
|
if ti.embedding_vector_length > 1:
|
||||||
print(
|
log.info(
|
||||||
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
|
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(
|
log.debug(
|
||||||
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
|
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||||
)
|
)
|
||||||
print(f" | The error was {str(e)}")
|
log.debug(f"The error was {str(e)}")
|
||||||
continue
|
continue
|
||||||
injected_token_ids.append(ti.trigger_token_id)
|
injected_token_ids.append(ti.trigger_token_id)
|
||||||
injected_token_ids.extend(ti.pad_token_ids)
|
injected_token_ids.extend(ti.pad_token_ids)
|
||||||
@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if suffix in [".pt",".ckpt",".bin"]:
|
if suffix in [".pt",".ckpt",".bin"]:
|
||||||
scan_result = scan_file_path(embedding_file)
|
scan_result = scan_file_path(embedding_file)
|
||||||
if scan_result.infected_files > 0:
|
if scan_result.infected_files > 0:
|
||||||
print(
|
log.critical(
|
||||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
)
|
)
|
||||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
log.critical("For your safety, InvokeAI will not load this embed.")
|
||||||
return list()
|
return list()
|
||||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||||
else:
|
else:
|
||||||
ckpt = safetensors.torch.load_file(embedding_file)
|
ckpt = safetensors.torch.load_file(embedding_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
log.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||||
return list()
|
return list()
|
||||||
|
|
||||||
# try to figure out what kind of embedding file it is and parse accordingly
|
# try to figure out what kind of embedding file it is and parse accordingly
|
||||||
@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
print(f' | Loading v1 embedding file: {basename}')
|
log.debug(f'Loading v1 embedding file: {basename}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
token_counter = -1
|
token_counter = -1
|
||||||
@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if token_counter < 0:
|
if token_counter < 0:
|
||||||
trigger = embedding_ckpt["name"]
|
trigger = embedding_ckpt["name"]
|
||||||
elif token_counter == 0:
|
elif token_counter == 0:
|
||||||
trigger = f'<basename>'
|
trigger = '<basename>'
|
||||||
else:
|
else:
|
||||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||||
token_counter += 1
|
token_counter += 1
|
||||||
@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
This handles embedding .pt file variant #2.
|
This handles embedding .pt file variant #2.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
print(f' | Loading v2 embedding file: {basename}')
|
log.debug(f'Loading v2 embedding file: {basename}')
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
)
|
)
|
||||||
embeddings.append(embedding_info)
|
embeddings.append(embedding_info)
|
||||||
else:
|
else:
|
||||||
print(f" ** {basename}: Unrecognized embedding format")
|
log.warning(f"{basename}: Unrecognized embedding format")
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
print(f' | Loading v3 embedding file: {basename}')
|
log.debug(f'Loading v3 embedding file: {basename}')
|
||||||
embedding = embedding_ckpt['emb_params']
|
embedding = embedding_ckpt['emb_params']
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
name = f'<{basename}>',
|
name = f'<{basename}>',
|
||||||
@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
basename = Path(filepath).stem
|
basename = Path(filepath).stem
|
||||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||||
|
|
||||||
print(f' | Loading v4 embedding file: {short_path}')
|
log.debug(f'Loading v4 embedding file: {short_path}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
if list(embedding_ckpt.keys()) == 0:
|
if list(embedding_ckpt.keys()) == 0:
|
||||||
print(f" ** Invalid embeddings file: {short_path}")
|
log.warning(f"Invalid embeddings file: {short_path}")
|
||||||
else:
|
else:
|
||||||
for token,embedding in embedding_ckpt.items():
|
for token,embedding in embedding_ckpt.items():
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
|
@ -18,6 +18,7 @@ import torch
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
from .devices import torch_dtype
|
from .devices import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
|
|||||||
try:
|
try:
|
||||||
draw.text((0, 0), lines, fill="black", font=font)
|
draw.text((0, 0), lines, fill="black", font=font)
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
print("Cant encode string for logging. Skipping.")
|
log.warning("Cant encode string for logging. Skipping.")
|
||||||
|
|
||||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||||
txts.append(txt)
|
txts.append(txt)
|
||||||
@ -80,8 +81,8 @@ def mean_flat(tensor):
|
|||||||
def count_params(model, verbose=False):
|
def count_params(model, verbose=False):
|
||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
print(
|
log.debug(
|
||||||
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||||
)
|
)
|
||||||
return total_params
|
return total_params
|
||||||
|
|
||||||
@ -132,8 +133,8 @@ def parallel_data_prefetch(
|
|||||||
raise ValueError("list expected but function got ndarray.")
|
raise ValueError("list expected but function got ndarray.")
|
||||||
elif isinstance(data, abc.Iterable):
|
elif isinstance(data, abc.Iterable):
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
print(
|
log.warning(
|
||||||
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||||
)
|
)
|
||||||
data = list(data.values())
|
data = list(data.values())
|
||||||
if target_data_type == "ndarray":
|
if target_data_type == "ndarray":
|
||||||
@ -175,7 +176,7 @@ def parallel_data_prefetch(
|
|||||||
processes += [p]
|
processes += [p]
|
||||||
|
|
||||||
# start processes
|
# start processes
|
||||||
print("Start prefetching...")
|
log.info("Start prefetching...")
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -194,7 +195,7 @@ def parallel_data_prefetch(
|
|||||||
gather_res[res[0]] = res[1]
|
gather_res[res[0]] = res[1]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception: ", e)
|
log.error("Exception: ", e)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
@ -202,7 +203,7 @@ def parallel_data_prefetch(
|
|||||||
finally:
|
finally:
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
log.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||||
|
|
||||||
if target_data_type == "ndarray":
|
if target_data_type == "ndarray":
|
||||||
if not isinstance(gather_res[0], np.ndarray):
|
if not isinstance(gather_res[0], np.ndarray):
|
||||||
@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||||
|
|
||||||
if exist_size > content_length:
|
if exist_size > content_length:
|
||||||
print("* corrupt existing file found. re-downloading")
|
log.warning("corrupt existing file found. re-downloading")
|
||||||
os.remove(dest)
|
os.remove(dest)
|
||||||
exist_size = 0
|
exist_size = 0
|
||||||
|
|
||||||
if resp.status_code == 416 or exist_size == content_length:
|
if resp.status_code == 416 or exist_size == content_length:
|
||||||
print(f"* {dest}: complete file found. Skipping.")
|
log.warning(f"{dest}: complete file found. Skipping.")
|
||||||
return dest
|
return dest
|
||||||
elif resp.status_code == 206 or exist_size > 0:
|
elif resp.status_code == 206 or exist_size > 0:
|
||||||
print(f"* {dest}: partial file found. Resuming...")
|
log.warning(f"{dest}: partial file found. Resuming...")
|
||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
print(f"** An error occurred during downloading {dest}: {resp.reason}")
|
log.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||||
else:
|
else:
|
||||||
print(f"* {dest}: Downloading...")
|
log.error(f"{dest}: Downloading...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if content_length < 2000:
|
if content_length < 2000:
|
||||||
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
|
log.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(dest, open_mode) as file, tqdm(
|
with open(dest, open_mode) as file, tqdm(
|
||||||
@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
bar.update(size)
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while downloading {dest}: {str(e)}")
|
log.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return dest
|
return dest
|
||||||
|
@ -19,6 +19,7 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as log
|
||||||
import invokeai.frontend.web.dist as frontend
|
import invokeai.frontend.web.dist as frontend
|
||||||
|
|
||||||
from .. import Generate
|
from .. import Generate
|
||||||
@ -213,7 +214,7 @@ class InvokeAIWebServer:
|
|||||||
self.load_socketio_listeners(self.socketio)
|
self.load_socketio_listeners(self.socketio)
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
print(">> Launching Invoke AI GUI")
|
log.info("Launching Invoke AI GUI")
|
||||||
try:
|
try:
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
|
||||||
@ -231,17 +232,17 @@ class InvokeAIWebServer:
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
useSSL = args.certfile or args.keyfile
|
useSSL = args.certfile or args.keyfile
|
||||||
print(">> Started Invoke AI Web Server")
|
log.info("Started Invoke AI Web Server")
|
||||||
if self.host == "0.0.0.0":
|
if self.host == "0.0.0.0":
|
||||||
print(
|
log.info(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
log.info(
|
||||||
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||||
)
|
)
|
||||||
print(
|
log.info(
|
||||||
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||||
)
|
)
|
||||||
if not useSSL:
|
if not useSSL:
|
||||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||||
@ -290,7 +291,7 @@ class InvokeAIWebServer:
|
|||||||
def load_socketio_listeners(self, socketio):
|
def load_socketio_listeners(self, socketio):
|
||||||
@socketio.on("requestSystemConfig")
|
@socketio.on("requestSystemConfig")
|
||||||
def handle_request_capabilities():
|
def handle_request_capabilities():
|
||||||
print(">> System config requested")
|
log.info("System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_manager.list_models()
|
config["model_list"] = self.generate.model_manager.list_models()
|
||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
@ -330,7 +331,7 @@ class InvokeAIWebServer:
|
|||||||
if model_name in current_model_list:
|
if model_name in current_model_list:
|
||||||
update = True
|
update = True
|
||||||
|
|
||||||
print(f">> Adding New Model: {model_name}")
|
log.info(f"Adding New Model: {model_name}")
|
||||||
|
|
||||||
self.generate.model_manager.add_model(
|
self.generate.model_manager.add_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -348,14 +349,14 @@ class InvokeAIWebServer:
|
|||||||
"update": update,
|
"update": update,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(f">> New Model Added: {model_name}")
|
log.info(f"New Model Added: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("deleteModel")
|
@socketio.on("deleteModel")
|
||||||
def handle_delete_model(model_name: str):
|
def handle_delete_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
print(f">> Deleting Model: {model_name}")
|
log.info(f"Deleting Model: {model_name}")
|
||||||
self.generate.model_manager.del_model(model_name)
|
self.generate.model_manager.del_model(model_name)
|
||||||
self.generate.model_manager.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
updated_model_list = self.generate.model_manager.list_models()
|
updated_model_list = self.generate.model_manager.list_models()
|
||||||
@ -366,14 +367,14 @@ class InvokeAIWebServer:
|
|||||||
"model_list": updated_model_list,
|
"model_list": updated_model_list,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(f">> Model Deleted: {model_name}")
|
log.info(f"Model Deleted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
def handle_set_model(model_name: str):
|
def handle_set_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
print(f">> Model change requested: {model_name}")
|
log.info(f"Model change requested: {model_name}")
|
||||||
model = self.generate.set_model(model_name)
|
model = self.generate.set_model(model_name)
|
||||||
model_list = self.generate.model_manager.list_models()
|
model_list = self.generate.model_manager.list_models()
|
||||||
if model is None:
|
if model is None:
|
||||||
@ -454,7 +455,7 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(f">> Model Converted: {model_name}")
|
log.info(f"Model Converted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -490,7 +491,7 @@ class InvokeAIWebServer:
|
|||||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||||
"vae", None
|
"vae", None
|
||||||
):
|
):
|
||||||
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
log.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
||||||
merged_model_config.update(vae=vae)
|
merged_model_config.update(vae=vae)
|
||||||
|
|
||||||
self.generate.model_manager.import_diffuser_model(
|
self.generate.model_manager.import_diffuser_model(
|
||||||
@ -507,8 +508,8 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(f">> Models Merged: {models_to_merge}")
|
log.info(f"Models Merged: {models_to_merge}")
|
||||||
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
log.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -698,7 +699,7 @@ class InvokeAIWebServer:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f">> Unable to load {path}")
|
log.info(f"Unable to load {path}")
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||||
)
|
)
|
||||||
@ -735,9 +736,9 @@ class InvokeAIWebServer:
|
|||||||
printable_parameters["init_mask"][:64] + "..."
|
printable_parameters["init_mask"][:64] + "..."
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
|
log.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||||
print(f">> ESRGAN Parameters: {esrgan_parameters}")
|
log.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
||||||
print(f">> Facetool Parameters: {facetool_parameters}")
|
log.info(f"Facetool Parameters: {facetool_parameters}")
|
||||||
|
|
||||||
self.generate_images(
|
self.generate_images(
|
||||||
generation_parameters,
|
generation_parameters,
|
||||||
@ -750,8 +751,8 @@ class InvokeAIWebServer:
|
|||||||
@socketio.on("runPostprocessing")
|
@socketio.on("runPostprocessing")
|
||||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||||
try:
|
try:
|
||||||
print(
|
log.info(
|
||||||
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||||
)
|
)
|
||||||
|
|
||||||
progress = Progress()
|
progress = Progress()
|
||||||
@ -861,14 +862,14 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
@socketio.on("cancel")
|
@socketio.on("cancel")
|
||||||
def handle_cancel():
|
def handle_cancel():
|
||||||
print(">> Cancel processing requested")
|
log.info("Cancel processing requested")
|
||||||
self.canceled.set()
|
self.canceled.set()
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on("deleteImage")
|
@socketio.on("deleteImage")
|
||||||
def handle_delete_image(url, thumbnail, uuid, category):
|
def handle_delete_image(url, thumbnail, uuid, category):
|
||||||
try:
|
try:
|
||||||
print(f'>> Delete requested "{url}"')
|
log.info(f'Delete requested "{url}"')
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
path = self.get_image_path_from_url(url)
|
path = self.get_image_path_from_url(url)
|
||||||
@ -1263,7 +1264,7 @@ class InvokeAIWebServer:
|
|||||||
image, os.path.basename(path), self.thumbnail_image_path
|
image, os.path.basename(path), self.thumbnail_image_path
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f'\n\n>> Image generated: "{path}"\n')
|
log.info(f'Image generated: "{path}"\n')
|
||||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
if progress.total_iterations > progress.current_iteration:
|
if progress.total_iterations > progress.current_iteration:
|
||||||
@ -1329,7 +1330,7 @@ class InvokeAIWebServer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clear the CUDA cache on an exception
|
# Clear the CUDA cache on an exception
|
||||||
self.empty_cuda_cache()
|
self.empty_cuda_cache()
|
||||||
print(e)
|
log.error(e)
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
def empty_cuda_cache(self):
|
def empty_cuda_cache(self):
|
||||||
|
Reference in New Issue
Block a user