convert remainder of print() to log.info()

This commit is contained in:
Lincoln Stein 2023-04-14 15:15:14 -04:00
parent c132dbdefa
commit 0b0e6fe448
22 changed files with 262 additions and 259 deletions

View File

@ -96,6 +96,7 @@ from pathlib import Path
from typing import List
import invokeai.version
import invokeai.backend.util.logging as log
from invokeai.backend.image_util import retrieve_metadata
from .globals import Globals
@ -189,7 +190,7 @@ class Args(object):
print(f"{APP_NAME} {APP_VERSION}")
sys.exit(0)
print("* Initializing, be patient...")
log.info("Initializing, be patient...")
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
Globals.try_patchmatch = switches.patchmatch
@ -197,14 +198,13 @@ class Args(object):
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
legacyinit = os.path.expanduser("~/.invokeai")
if os.path.exists(initfile):
print(
f">> Initialization file {initfile} found. Loading...",
file=sys.stderr,
log.info(
f"Initialization file {initfile} found. Loading...",
)
sysargs.insert(0, f"@{initfile}")
elif os.path.exists(legacyinit):
print(
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
log.warning(
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
)
sysargs.insert(0, f"@{legacyinit}")
Globals.log_tokenization = self._arg_parser.parse_args(
@ -214,7 +214,7 @@ class Args(object):
self._arg_switches = self._arg_parser.parse_args(sysargs)
return self._arg_switches
except Exception as e:
print(f"An exception has occurred: {e}")
log.error(f"An exception has occurred: {e}")
return None
def parse_cmd(self, cmd_string):
@ -1154,7 +1154,7 @@ class Args(object):
def format_metadata(**kwargs):
print("format_metadata() is deprecated. Please use metadata_dumps()")
log.warning("format_metadata() is deprecated. Please use metadata_dumps()")
return metadata_dumps(kwargs)
@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
import sys
import traceback
print(">> could not read metadata", file=sys.stderr)
log.error("Could not read metadata")
print(traceback.format_exc(), file=sys.stderr)
return results

View File

@ -27,6 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
import invokeai.backend.util.logging as log
from .args import metadata_from_png
from .generator import infill_methods
from .globals import Globals, global_cache_dir
@ -195,12 +196,12 @@ class Generate:
# device to Generate(). However the device was then ignored, so
# it wasn't actually doing anything. This logic could be reinstated.
self.device = torch.device(choose_torch_device())
print(f">> Using device_type {self.device.type}")
log.info(f"Using device_type {self.device.type}")
if full_precision:
if self.precision != "auto":
raise ValueError("Remove --full_precision / -F if using --precision")
print("Please remove deprecated --full_precision / -F")
print("If auto config does not work you can use --precision=float32")
log.warning("Please remove deprecated --full_precision / -F")
log.warning("If auto config does not work you can use --precision=float32")
self.precision = "float32"
if self.precision == "auto":
self.precision = choose_precision(self.device)
@ -208,13 +209,13 @@ class Generate:
if is_xformers_available():
if torch.cuda.is_available() and not Globals.disable_xformers:
print(">> xformers memory-efficient attention is available and enabled")
log.info("xformers memory-efficient attention is available and enabled")
else:
print(
">> xformers memory-efficient attention is available but disabled"
log.info(
"xformers memory-efficient attention is available but disabled"
)
else:
print(">> xformers not installed")
log.info("xformers not installed")
# model caching system for fast switching
self.model_manager = ModelManager(
@ -229,8 +230,8 @@ class Generate:
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
if not self.model_manager.valid_model(model):
print(
f'** "{model}" is not a known model name; falling back to {fallback}.'
log.warning(
f'"{model}" is not a known model name; falling back to {fallback}.'
)
model = None
self.model_name = model or fallback
@ -246,10 +247,10 @@ class Generate:
# load safety checker if requested
if safety_checker:
print(">> Initializing NSFW checker")
log.info("Initializing NSFW checker")
self.safety_checker = SafetyChecker(self.device)
else:
print(">> NSFW checker is disabled")
log.info("NSFW checker is disabled")
def prompt2png(self, prompt, outdir, **kwargs):
"""
@ -567,7 +568,7 @@ class Generate:
self.clear_cuda_cache()
if catch_interrupts:
print("**Interrupted** Partial results will be returned.")
log.warning("Interrupted** Partial results will be returned.")
else:
raise KeyboardInterrupt
except RuntimeError:
@ -575,11 +576,11 @@ class Generate:
self.clear_cuda_cache()
print(traceback.format_exc(), file=sys.stderr)
print(">> Could not generate image.")
log.info("Could not generate image.")
toc = time.time()
print("\n>> Usage stats:")
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
log.info("Usage stats:")
log.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
self.print_cuda_stats()
return results
@ -609,16 +610,16 @@ class Generate:
def print_cuda_stats(self):
if self._has_cuda():
self.gather_cuda_stats()
print(
">> Max VRAM used for this generation:",
"%4.2fG." % (self.max_memory_allocated / 1e9),
"Current VRAM utilization:",
"%4.2fG" % (self.memory_allocated / 1e9),
log.info(
"Max VRAM used for this generation: "+
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
"Current VRAM utilization: "+
"%4.2fG" % (self.memory_allocated / 1e9)
)
print(
">> Max VRAM used since script start: ",
"%4.2fG" % (self.session_peakmem / 1e9),
log.info(
"Max VRAM used since script start: " +
"%4.2fG" % (self.session_peakmem / 1e9)
)
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
@ -647,7 +648,7 @@ class Generate:
seed = random.randrange(0, np.iinfo(np.uint32).max)
prompt = opt.prompt or args.prompt or ""
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
log.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
@ -696,8 +697,8 @@ class Generate:
try:
extend_instructions[direction] = int(pixels)
except ValueError:
print(
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
log.warning(
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
)
opt.seed = seed
@ -720,8 +721,8 @@ class Generate:
# fetch the metadata from the image
generator = self.select_generator(embiggen=True)
opt.strength = opt.embiggen_strength or 0.40
print(
f">> Setting img2img strength to {opt.strength} for happy embiggening"
log.info(
f"Setting img2img strength to {opt.strength} for happy embiggening"
)
generator.generate(
prompt,
@ -748,12 +749,12 @@ class Generate:
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
elif tool is None:
print(
"* please provide at least one postprocessing option, such as -G or -U"
log.warning(
"please provide at least one postprocessing option, such as -G or -U"
)
return None
else:
print(f"* postprocessing tool {tool} is not yet supported")
log.warning(f"postprocessing tool {tool} is not yet supported")
return None
def select_generator(
@ -797,8 +798,8 @@ class Generate:
image = self._load_img(img)
if image.width < self.width and image.height < self.height:
print(
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
log.warning(
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
)
# if image has a transparent area and no mask was provided, then try to generate mask
@ -809,8 +810,8 @@ class Generate:
if (image.width * image.height) > (
self.width * self.height
) and self.size_matters:
print(
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
log.info(
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
)
self.size_matters = False
@ -891,11 +892,11 @@ class Generate:
try:
model_data = cache.get_model(model_name)
except Exception as e:
print(f"** model {model_name} could not be loaded: {str(e)}")
log.warning(f"model {model_name} could not be loaded: {str(e)}")
print(traceback.format_exc(), file=sys.stderr)
if previous_model_name is None:
raise e
print("** trying to reload previous model")
log.warning("trying to reload previous model")
model_data = cache.get_model(previous_model_name) # load previous
if model_data is None:
raise e
@ -962,15 +963,15 @@ class Generate:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
print(
">> GFPGAN not found. Face restoration is disabled."
log.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
print(
">> CodeFormer not found. Face restoration is disabled."
log.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@ -984,7 +985,7 @@ class Generate:
fidelity=codeformer_fidelity,
)
else:
print(">> Face Restoration is disabled.")
log.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@ -997,10 +998,10 @@ class Generate:
denoise_str=upscale_denoise_str,
)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
log.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
log.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
@ -1066,17 +1067,17 @@ class Generate:
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
msg = (
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
)
self.sampler = sampler_class.from_config(self.model.scheduler.config)
else:
msg = (
f">> Unsupported Sampler: {self.sampler_name} "
f" Unsupported Sampler: {self.sampler_name} "+
f"Defaulting to {default}"
)
self.sampler = default
print(msg)
log.info(msg)
if not hasattr(self.sampler, "uses_inpainting_model"):
# FIXME: terrible kludge!
@ -1085,17 +1086,17 @@ class Generate:
def _load_img(self, img) -> Image:
if isinstance(img, Image.Image):
image = img
print(f">> using provided input image of size {image.width}x{image.height}")
log.info(f"using provided input image of size {image.width}x{image.height}")
elif isinstance(img, str):
assert os.path.exists(img), f">> {img}: File not found"
image = Image.open(img)
print(
f">> loaded input image of size {image.width}x{image.height} from {img}"
log.info(
f"loaded input image of size {image.width}x{image.height} from {img}"
)
else:
image = Image.open(img)
print(f">> loaded input image of size {image.width}x{image.height}")
log.info(f"loaded input image of size {image.width}x{image.height}")
image = ImageOps.exif_transpose(image)
return image
@ -1183,14 +1184,14 @@ class Generate:
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
if not mask:
print(
">> Initial image has transparent areas. Will inpaint in these regions."
log.info(
"Initial image has transparent areas. Will inpaint in these regions."
)
if (not force_outpaint) and self._check_for_erasure(image):
print(
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
if (not force_outpaint) and self._check_for_erasure(image):
log.info(
"Colors underneath the transparent region seem to have been erased.\n" +
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
)
def _squeeze_image(self, image):
@ -1201,11 +1202,11 @@ class Generate:
def _fit_image(self, image, max_dimensions):
w, h = max_dimensions
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
log.info(f"image will be resized to fit inside a box {w}x{h} in size.")
# note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(width=w, height=h)
print(
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
log.info(
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
)
return image
@ -1216,8 +1217,8 @@ class Generate:
) # resize to integer multiple of 64
if h != height or w != width:
if log:
print(
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
log.info(
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
)
height = h
width = w

View File

@ -25,6 +25,7 @@ from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as log
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
@ -372,7 +373,7 @@ class Generator:
try:
x_T = self.get_noise(width, height)
except:
print("** An error occurred while getting initial noise **")
log.error("An error occurred while getting initial noise")
print(traceback.format_exc())
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
@ -607,7 +608,7 @@ class Generator:
image = self.sample_to_image(sample)
dirname = os.path.dirname(filepath) or "."
if not os.path.exists(dirname):
print(f"** creating directory {dirname}")
log.info(f"creating directory {dirname}")
os.makedirs(dirname, exist_ok=True)
image.save(filepath, "PNG")

View File

@ -8,10 +8,11 @@ import torch
from PIL import Image
from tqdm import trange
import invokeai.backend.util.logging as log
from .base import Generator
from .img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@ -72,22 +73,22 @@ class Embiggen(Generator):
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
print(
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
log.warning(
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
)
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
print(
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
log.warning(
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
)
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
print(
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
log.warning(
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
)
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
@ -97,8 +98,8 @@ class Embiggen(Generator):
embiggen_tiles.sort()
if strength >= 0.5:
print(
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
log.warning(
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
)
# Prep img2img generator, since we wrap over it
@ -121,8 +122,8 @@ class Embiggen(Generator):
from ..restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
print(
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
log.info(
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
)
if embiggen[0] > 2:
initsuperimage = esrgan.process(
@ -312,10 +313,10 @@ class Embiggen(Generator):
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
log.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
else:
print(
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
log.info(
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)
emb_tile_store = []
@ -361,11 +362,11 @@ class Embiggen(Generator):
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
print(
log.debug(
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
log.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
@ -547,8 +548,8 @@ class Embiggen(Generator):
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
print(
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
log.error(
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)
# after internal loops and patching up return Embiggen image

View File

@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
from ..stable_diffusion.diffusers_pipeline import ConditioningData
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
import invokeai.backend.util.logging as log
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
# the message below is accurate.
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
log.info(
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing

View File

@ -5,10 +5,9 @@ wraps the actual patchmatch object. It respects the global
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
@ -28,12 +27,12 @@ class PatchMatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
print(">> Patchmatch initialized")
log.info("Patchmatch initialized")
else:
print(">> Patchmatch not loaded (nonfatal)")
log.info("Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
print(">> Patchmatch loading disabled")
log.info("Patchmatch loading disabled")
self.tried_load = True
@classmethod

View File

@ -30,9 +30,9 @@ work fine.
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as log
from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
@ -83,7 +83,7 @@ class Txt2Mask(object):
"""
def __init__(self, device="cpu", refined=False):
print(">> Initializing clipseg model for text to mask inference")
log.info("Initializing clipseg model for text to mask inference")
# BUG: we are not doing anything with the device option at this time
self.device = device
@ -101,18 +101,6 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
transforms.Resize(
(CLIPSEG_SIZE, CLIPSEG_SIZE)
), # must be multiple of 64...
]
)
if type(image) is str:
image = Image.open(image).convert("RGB")

View File

@ -25,6 +25,7 @@ from typing import Union
import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as log
from invokeai.backend.globals import global_cache_dir, global_config_dir
from .model_manager import ModelManager, SDLegacyType
@ -372,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
log.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(" | Extracting EMA weights (usually better for inference)")
log.debug("Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
@ -392,8 +393,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
log.debug(
"Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
log.debug("global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@ -1229,15 +1230,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
print(f" | Converting VAE {vae_path}")
log.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
print(" | Using checkpoint model's original VAE")
log.debug("Using checkpoint model's original VAE")
if vae:
print(" | Using replacement diffusers VAE")
log.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size

View File

@ -18,6 +18,7 @@ from compel.prompt_parser import (
PromptParser,
)
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
from ..stable_diffusion import InvokeAIDiffuserComponent
@ -162,8 +163,8 @@ def log_tokenization(
negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
log.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(
@ -237,12 +238,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
usedTokens += 1
if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"{tokenized}\x1b[0m")
log.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
log.debug(f"{tokenized}\x1b[0m")
if discarded != "":
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
print(f"{discarded}\x1b[0m")
log.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
log.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
@ -295,8 +296,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
log.warning(
"Subprompt weights add up to zero. Discarding and using even weights instead."
)
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]

View File

@ -1,3 +1,5 @@
import invokeai.backend.util.logging as log
class Restoration:
def __init__(self) -> None:
pass
@ -8,17 +10,17 @@ class Restoration:
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
print(">> GFPGAN Initialized")
log.info("GFPGAN Initialized")
else:
print(">> GFPGAN Disabled")
log.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
print(">> CodeFormer Initialized")
log.info("CodeFormer Initialized")
else:
print(">> CodeFormer Disabled")
log.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
@ -39,5 +41,5 @@ class Restoration:
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
print(">> ESRGAN Initialized")
log.info("ESRGAN Initialized")
return esrgan

View File

@ -5,6 +5,7 @@ import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as log
from ..globals import Globals
pretrained_model_url = (
@ -23,12 +24,12 @@ class CodeFormerRestoration:
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
log.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
log.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@ -97,7 +98,7 @@ class CodeFormerRestoration:
del output
torch.cuda.empty_cache()
except RuntimeError as error:
print(f"\tFailed inference for CodeFormer: {error}.")
log.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")

View File

@ -6,9 +6,9 @@ import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
if not os.path.isabs(gfpgan_model_path):
@ -19,7 +19,7 @@ class GFPGAN:
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
log.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
return None
def model_exists(self):
@ -27,7 +27,7 @@ class GFPGAN:
def process(self, image, strength: float, seed: str = None):
if seed is not None:
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
log.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
@ -47,14 +47,14 @@ class GFPGAN:
except Exception:
import traceback
print(">> Error loading GFPGAN:", file=sys.stderr)
log.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
print(f">> WARNING: GFPGAN not initialized.")
print(
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
log.warning("WARNING: GFPGAN not initialized.")
log.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")

View File

@ -1,7 +1,7 @@
import math
from PIL import Image
import invokeai.backend.util.logging as log
class Outcrop(object):
def __init__(
@ -82,7 +82,7 @@ class Outcrop(object):
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
print(f">> extending image {direction}ward by {pixels} pixels")
log.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)

View File

@ -6,18 +6,13 @@ import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
@ -74,16 +69,16 @@ class ESRGAN:
import sys
import traceback
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
log.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
log.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
print(
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
log.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")

View File

@ -14,6 +14,7 @@ from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as log
from .globals import global_cache_dir
from .util import CPU_DEVICE
@ -40,8 +41,8 @@ class SafetyChecker(object):
cache_dir=safety_model_path,
)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
log.error(
"An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
@ -65,8 +66,8 @@ class SafetyChecker(object):
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
log.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
)
return self.blur(image)
else:

View File

@ -17,6 +17,7 @@ from huggingface_hub import (
hf_hub_url,
)
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
@ -66,11 +67,11 @@ class HuggingFaceConceptsLibrary(object):
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
print(
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
log.warning(
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
print(
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
log.warning(
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
@ -81,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
"""
if not concept_name in self.list_concepts():
print(
log.warning(
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
@ -219,7 +220,7 @@ class HuggingFaceConceptsLibrary(object):
if chunk == 0:
bytes += total
print(f">> Downloading {repo_id}...", end="")
log.info(f"Downloading {repo_id}...", end="")
try:
for file in (
"README.md",
@ -233,22 +234,22 @@ class HuggingFaceConceptsLibrary(object):
)
except ul_error.HTTPError as e:
if e.code == 404:
print(
log.warning(
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
log.warning(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest)
return False
except ul_error.URLError as e:
print(
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
log.error(
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False
print("...{:.2f}Kb".format(bytes / 1024))
log.info("...{:.2f}Kb".format(bytes / 1024))
return succeeded
def _concept_id(self, concept_name: str) -> str:

View File

@ -14,9 +14,9 @@ from diffusers.models.cross_attention import AttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as log
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -425,13 +425,13 @@ def get_cross_attention_modules(
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
print(
log.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ f"work properly until it is fixed."
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ "work properly until it is fixed."
)
return attention_module_tuples

View File

@ -8,6 +8,7 @@ import torch
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as log
from invokeai.backend.globals import Globals
from .cross_attention_control import (
@ -262,7 +263,7 @@ class InvokeAIDiffuserComponent:
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
f"Either step_index or sigma is required when doing cross attention control, but both are None."
"Either step_index or sigma is required when doing cross attention control, but both are None."
)
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
@ -466,10 +467,14 @@ class InvokeAIDiffuserComponent:
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
print(
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
log.info(
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
)
log.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
log.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)
if maxval < current_threshold and minval > -current_threshold:
@ -496,9 +501,11 @@ class InvokeAIDiffuserComponent:
)
if self.debug_thresholding:
print(
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
log.debug(
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
)
log.debug(
f"{num_altered / latents.numel() * 100:.2f}% values altered"
)
return latents
@ -599,7 +606,6 @@ class InvokeAIDiffuserComponent:
)
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = ceil(len(conditionings) / 2)

View File

@ -10,7 +10,7 @@ from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as log
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -191,7 +191,7 @@ def mkdirs(paths):
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + "_archived_" + get_timestamp()
print("Path already exists. Rename it to [{:s}]".format(new_name))
log.error("Path already exists. Rename it to [{:s}]".format(new_name))
os.replace(path, new_name)
os.makedirs(path)

View File

@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as log
from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
print(f">> Loaded local embedding for trigger {concept_name}")
log.info(f"Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f">> Loaded remote embedding for trigger {concept_name}")
log.info(f"Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name] = True
@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
embedding_list = self._parse_embedding(str(ckpt_path))
for embedding_info in embedding_list:
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print(
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
log.warning(
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
)
continue
@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
log.info(
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
log.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
log.debug(f"The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
log.warning(
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
except ValueError as e:
if str(e).startswith("Warning"):
print(f">> {str(e)}")
log.warning(f"{str(e)}")
else:
traceback.print_exc()
print(
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
log.error(
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
print(
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
log.info(
f"Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
print(
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
log.debug(
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
)
print(f" | The error was {str(e)}")
log.debug(f"The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
if suffix in [".pt",".ckpt",".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
log.critical(
f"Security Issues Found in Model: {scan_result.issues_count}"
)
print(" ** For your safety, InvokeAI will not load this embed.")
log.critical("For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
log.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
print(f' | Loading v1 embedding file: {basename}')
log.debug(f'Loading v1 embedding file: {basename}')
embeddings = list()
token_counter = -1
@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
if token_counter < 0:
trigger = embedding_ckpt["name"]
elif token_counter == 0:
trigger = f'<basename>'
trigger = '<basename>'
else:
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
token_counter += 1
@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
This handles embedding .pt file variant #2.
"""
basename = Path(file_path).stem
print(f' | Loading v2 embedding file: {basename}')
log.debug(f'Loading v2 embedding file: {basename}')
embeddings = list()
if isinstance(
@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
)
embeddings.append(embedding_info)
else:
print(f" ** {basename}: Unrecognized embedding format")
log.warning(f"{basename}: Unrecognized embedding format")
return embeddings
@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
print(f' | Loading v3 embedding file: {basename}')
log.debug(f'Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
print(f' | Loading v4 embedding file: {short_path}')
log.debug(f'Loading v4 embedding file: {short_path}')
embeddings = list()
if list(embedding_ckpt.keys()) == 0:
print(f" ** Invalid embeddings file: {short_path}")
log.warning(f"Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(

View File

@ -18,6 +18,7 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import invokeai.backend.util.logging as log
from .devices import torch_dtype
@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
log.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -80,8 +81,8 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
log.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
@ -132,8 +133,8 @@ def parallel_data_prefetch(
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
log.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
@ -175,7 +176,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print("Start prefetching...")
log.info("Start prefetching...")
import time
start = time.time()
@ -194,7 +195,7 @@ def parallel_data_prefetch(
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
log.error("Exception: ", e)
for p in processes:
p.terminate()
@ -202,7 +203,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
log.info(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
resp = requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
print("* corrupt existing file found. re-downloading")
log.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or exist_size == content_length:
print(f"* {dest}: complete file found. Skipping.")
log.warning(f"{dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
log.warning(f"{dest}: partial file found. Resuming...")
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
log.error(f"An error occurred during downloading {dest}: {resp.reason}")
else:
print(f"* {dest}: Downloading...")
log.error(f"{dest}: Downloading...")
try:
if content_length < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
log.error(f"ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
log.error(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest

View File

@ -19,6 +19,7 @@ from PIL import Image
from PIL.Image import Image as ImageType
from werkzeug.utils import secure_filename
import invokeai.backend.util.logging as log
import invokeai.frontend.web.dist as frontend
from .. import Generate
@ -213,7 +214,7 @@ class InvokeAIWebServer:
self.load_socketio_listeners(self.socketio)
if args.gui:
print(">> Launching Invoke AI GUI")
log.info("Launching Invoke AI GUI")
try:
from flaskwebgui import FlaskUI
@ -231,17 +232,17 @@ class InvokeAIWebServer:
sys.exit(0)
else:
useSSL = args.certfile or args.keyfile
print(">> Started Invoke AI Web Server")
log.info("Started Invoke AI Web Server")
if self.host == "0.0.0.0":
print(
log.info(
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
)
else:
print(
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
log.info(
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
)
print(
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
log.info(
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
)
if not useSSL:
self.socketio.run(app=self.app, host=self.host, port=self.port)
@ -290,7 +291,7 @@ class InvokeAIWebServer:
def load_socketio_listeners(self, socketio):
@socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(">> System config requested")
log.info("System config requested")
config = self.get_system_config()
config["model_list"] = self.generate.model_manager.list_models()
config["infill_methods"] = infill_methods()
@ -330,7 +331,7 @@ class InvokeAIWebServer:
if model_name in current_model_list:
update = True
print(f">> Adding New Model: {model_name}")
log.info(f"Adding New Model: {model_name}")
self.generate.model_manager.add_model(
model_name=model_name,
@ -348,14 +349,14 @@ class InvokeAIWebServer:
"update": update,
},
)
print(f">> New Model Added: {model_name}")
log.info(f"New Model Added: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("deleteModel")
def handle_delete_model(model_name: str):
try:
print(f">> Deleting Model: {model_name}")
log.info(f"Deleting Model: {model_name}")
self.generate.model_manager.del_model(model_name)
self.generate.model_manager.commit(opt.conf)
updated_model_list = self.generate.model_manager.list_models()
@ -366,14 +367,14 @@ class InvokeAIWebServer:
"model_list": updated_model_list,
},
)
print(f">> Model Deleted: {model_name}")
log.info(f"Model Deleted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("requestModelChange")
def handle_set_model(model_name: str):
try:
print(f">> Model change requested: {model_name}")
log.info(f"Model change requested: {model_name}")
model = self.generate.set_model(model_name)
model_list = self.generate.model_manager.list_models()
if model is None:
@ -454,7 +455,7 @@ class InvokeAIWebServer:
"update": True,
},
)
print(f">> Model Converted: {model_name}")
log.info(f"Model Converted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@ -490,7 +491,7 @@ class InvokeAIWebServer:
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
"vae", None
):
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
log.info(f"Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
@ -507,8 +508,8 @@ class InvokeAIWebServer:
"update": True,
},
)
print(f">> Models Merged: {models_to_merge}")
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
log.info(f"Models Merged: {models_to_merge}")
log.info(f"New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@ -698,7 +699,7 @@ class InvokeAIWebServer:
}
)
except Exception as e:
print(f">> Unable to load {path}")
log.info(f"Unable to load {path}")
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
@ -735,9 +736,9 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..."
)
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
print(f">> ESRGAN Parameters: {esrgan_parameters}")
print(f">> Facetool Parameters: {facetool_parameters}")
log.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
log.info(f"ESRGAN Parameters: {esrgan_parameters}")
log.info(f"Facetool Parameters: {facetool_parameters}")
self.generate_images(
generation_parameters,
@ -750,8 +751,8 @@ class InvokeAIWebServer:
@socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters):
try:
print(
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
log.info(
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
)
progress = Progress()
@ -861,14 +862,14 @@ class InvokeAIWebServer:
@socketio.on("cancel")
def handle_cancel():
print(">> Cancel processing requested")
log.info("Cancel processing requested")
self.canceled.set()
# TODO: I think this needs a safety mechanism.
@socketio.on("deleteImage")
def handle_delete_image(url, thumbnail, uuid, category):
try:
print(f'>> Delete requested "{url}"')
log.info(f'Delete requested "{url}"')
from send2trash import send2trash
path = self.get_image_path_from_url(url)
@ -1263,7 +1264,7 @@ class InvokeAIWebServer:
image, os.path.basename(path), self.thumbnail_image_path
)
print(f'\n\n>> Image generated: "{path}"\n')
log.info(f'Image generated: "{path}"\n')
self.write_log_message(f'[Generated] "{path}": {command}')
if progress.total_iterations > progress.current_iteration:
@ -1329,7 +1330,7 @@ class InvokeAIWebServer:
except Exception as e:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
print(e)
log.error(e)
self.handle_exceptions(e)
def empty_cuda_cache(self):