mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add logging support
This commit adds invokeai.backend.util.logging, which provides support for formatted console and logfile messages that follow the status reporting conventions of earlier InvokeAI versions. Examples: ### A critical error (logging.CRITICAL) *** A non-fatal error (logging.ERROR) ** A warning (logging.WARNING) >> Informational message (logging.INFO) | Debugging message (logging.DEBUG)
This commit is contained in:
parent
d923d1d66b
commit
5a4765046e
@ -24,6 +24,7 @@ import safetensors
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
import invokeai.backend.util.logging as ialog
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
@ -132,8 +133,8 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
if not self.valid_model(model_name):
|
||||||
print(
|
ialog.error(
|
||||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return self.current_model
|
return self.current_model
|
||||||
|
|
||||||
@ -144,7 +145,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
if model_name in self.models:
|
if model_name in self.models:
|
||||||
requested_model = self.models[model_name]["model"]
|
requested_model = self.models[model_name]["model"]
|
||||||
print(f">> Retrieving model {model_name} from system RAM cache")
|
ialog.info(f"Retrieving model {model_name} from system RAM cache")
|
||||||
requested_model.ready()
|
requested_model.ready()
|
||||||
width = self.models[model_name]["width"]
|
width = self.models[model_name]["width"]
|
||||||
height = self.models[model_name]["height"]
|
height = self.models[model_name]["height"]
|
||||||
@ -379,7 +380,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
omega = self.config
|
omega = self.config
|
||||||
if model_name not in omega:
|
if model_name not in omega:
|
||||||
print(f"** Unknown model {model_name}")
|
ialog.error(f"Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
# save these for use in deletion later
|
# save these for use in deletion later
|
||||||
conf = omega[model_name]
|
conf = omega[model_name]
|
||||||
@ -392,13 +393,13 @@ class ModelManager(object):
|
|||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
if delete_files:
|
if delete_files:
|
||||||
if weights:
|
if weights:
|
||||||
print(f"** Deleting file {weights}")
|
ialog.info(f"Deleting file {weights}")
|
||||||
Path(weights).unlink(missing_ok=True)
|
Path(weights).unlink(missing_ok=True)
|
||||||
elif path:
|
elif path:
|
||||||
print(f"** Deleting directory {path}")
|
ialog.info(f"Deleting directory {path}")
|
||||||
rmtree(path, ignore_errors=True)
|
rmtree(path, ignore_errors=True)
|
||||||
elif repo_id:
|
elif repo_id:
|
||||||
print(f"** Deleting the cached model directory for {repo_id}")
|
ialog.info(f"Deleting the cached model directory for {repo_id}")
|
||||||
self._delete_model_from_cache(repo_id)
|
self._delete_model_from_cache(repo_id)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
@ -439,7 +440,7 @@ class ModelManager(object):
|
|||||||
def _load_model(self, model_name: str):
|
def _load_model(self, model_name: str):
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||||
if model_name not in self.config:
|
if model_name not in self.config:
|
||||||
print(
|
ialog.error(
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -457,7 +458,7 @@ class ModelManager(object):
|
|||||||
model_format = mconfig.get("format", "ckpt")
|
model_format = mconfig.get("format", "ckpt")
|
||||||
if model_format == "ckpt":
|
if model_format == "ckpt":
|
||||||
weights = mconfig.weights
|
weights = mconfig.weights
|
||||||
print(f">> Loading {model_name} from {weights}")
|
ialog.info(f"Loading {model_name} from {weights}")
|
||||||
model, width, height, model_hash = self._load_ckpt_model(
|
model, width, height, model_hash = self._load_ckpt_model(
|
||||||
model_name, mconfig
|
model_name, mconfig
|
||||||
)
|
)
|
||||||
@ -473,13 +474,15 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# usage statistics
|
# usage statistics
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
ialog.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
print(
|
ialog.info(
|
||||||
">> Max VRAM used to load the model:",
|
"Max VRAM used to load the model: "+
|
||||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||||
"\n>> Current VRAM usage:"
|
)
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
ialog.info(
|
||||||
|
"Current VRAM usage: "+
|
||||||
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||||
)
|
)
|
||||||
return model, width, height, model_hash
|
return model, width, height, model_hash
|
||||||
|
|
||||||
@ -487,11 +490,11 @@ class ModelManager(object):
|
|||||||
name_or_path = self.model_name_or_path(mconfig)
|
name_or_path = self.model_name_or_path(mconfig)
|
||||||
using_fp16 = self.precision == "float16"
|
using_fp16 = self.precision == "float16"
|
||||||
|
|
||||||
print(f">> Loading diffusers model from {name_or_path}")
|
ialog.info(f"Loading diffusers model from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
print(" | Using faster float16 precision")
|
ialog.debug("Using faster float16 precision")
|
||||||
else:
|
else:
|
||||||
print(" | Using more accurate float32 precision")
|
ialog.debug("Using more accurate float32 precision")
|
||||||
|
|
||||||
# TODO: scan weights maybe?
|
# TODO: scan weights maybe?
|
||||||
pipeline_args: dict[str, Any] = dict(
|
pipeline_args: dict[str, Any] = dict(
|
||||||
@ -523,8 +526,8 @@ class ModelManager(object):
|
|||||||
if str(e).startswith("fp16 is not a valid"):
|
if str(e).startswith("fp16 is not a valid"):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print(
|
ialog.error(
|
||||||
f"** An unexpected error occurred while downloading the model: {e})"
|
f"An unexpected error occurred while downloading the model: {e})"
|
||||||
)
|
)
|
||||||
if pipeline:
|
if pipeline:
|
||||||
break
|
break
|
||||||
@ -542,7 +545,7 @@ class ModelManager(object):
|
|||||||
# square images???
|
# square images???
|
||||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||||
height = width
|
height = width
|
||||||
print(f" | Default image dimensions = {width} x {height}")
|
ialog.debug(f"Default image dimensions = {width} x {height}")
|
||||||
|
|
||||||
return pipeline, width, height, model_hash
|
return pipeline, width, height, model_hash
|
||||||
|
|
||||||
@ -559,14 +562,14 @@ class ModelManager(object):
|
|||||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||||
|
|
||||||
# Convert to diffusers and return a diffusers pipeline
|
# Convert to diffusers and return a diffusers pipeline
|
||||||
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
|
ialog.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||||
|
|
||||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.list_models()[self.current_model]["status"] == "active":
|
if self.list_models()[self.current_model]["status"] == "active":
|
||||||
self.offload_model(self.current_model)
|
self.offload_model(self.current_model)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
vae_path = None
|
vae_path = None
|
||||||
@ -624,7 +627,7 @@ class ModelManager(object):
|
|||||||
if model_name not in self.models:
|
if model_name not in self.models:
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f">> Offloading {model_name} to CPU")
|
ialog.info(f"Offloading {model_name} to CPU")
|
||||||
model = self.models[model_name]["model"]
|
model = self.models[model_name]["model"]
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
self.current_model = None
|
self.current_model = None
|
||||||
@ -640,30 +643,26 @@ class ModelManager(object):
|
|||||||
and option to exit if an infected file is identified.
|
and option to exit if an infected file is identified.
|
||||||
"""
|
"""
|
||||||
# scan model
|
# scan model
|
||||||
print(f" | Scanning Model: {model_name}")
|
ialog.debug(f"Scanning Model: {model_name}")
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
|
ialog.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||||
print(
|
ialog.critical("The model you are trying to load seems to be infected.")
|
||||||
"### WARNING: The model you are trying to load seems to be infected."
|
ialog.critical("For your safety, InvokeAI will not load this model.")
|
||||||
)
|
ialog.critical("Please use checkpoints from trusted sources.")
|
||||||
print("### For your safety, InvokeAI will not load this model.")
|
ialog.critical("Exiting InvokeAI")
|
||||||
print("### Please use checkpoints from trusted sources.")
|
|
||||||
print("### Exiting InvokeAI")
|
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
print(
|
ialog.warning("InvokeAI was unable to scan the model you are using.")
|
||||||
"\n### WARNING: InvokeAI was unable to scan the model you are using."
|
|
||||||
)
|
|
||||||
model_safe_check_fail = ask_user(
|
model_safe_check_fail = ask_user(
|
||||||
"Do you want to to continue loading the model?", ["y", "n"]
|
"Do you want to to continue loading the model?", ["y", "n"]
|
||||||
)
|
)
|
||||||
if model_safe_check_fail.lower() != "y":
|
if model_safe_check_fail.lower() != "y":
|
||||||
print("### Exiting InvokeAI")
|
ialog.critical("Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
print(" | Model scanned ok")
|
ialog.debug("Model scanned ok")
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@ -780,26 +779,24 @@ class ModelManager(object):
|
|||||||
model_path: Path = None
|
model_path: Path = None
|
||||||
thing = path_url_or_repo # to save typing
|
thing = path_url_or_repo # to save typing
|
||||||
|
|
||||||
print(f">> Probing {thing} for import")
|
ialog.info(f"Probing {thing} for import")
|
||||||
|
|
||||||
if thing.startswith(("http:", "https:", "ftp:")):
|
if thing.startswith(("http:", "https:", "ftp:")):
|
||||||
print(f" | {thing} appears to be a URL")
|
ialog.info(f"{thing} appears to be a URL")
|
||||||
model_path = self._resolve_path(
|
model_path = self._resolve_path(
|
||||||
thing, "models/ldm/stable-diffusion-v1"
|
thing, "models/ldm/stable-diffusion-v1"
|
||||||
) # _resolve_path does a download if needed
|
) # _resolve_path does a download if needed
|
||||||
|
|
||||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||||
print(
|
ialog.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
ialog.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||||
|
|
||||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||||
print(f" | {thing} appears to be a diffusers file on disk")
|
ialog.debug(f"{thing} appears to be a diffusers file on disk")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing,
|
thing,
|
||||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||||
@ -810,34 +807,30 @@ class ModelManager(object):
|
|||||||
|
|
||||||
elif Path(thing).is_dir():
|
elif Path(thing).is_dir():
|
||||||
if (Path(thing) / "model_index.json").exists():
|
if (Path(thing) / "model_index.json").exists():
|
||||||
print(f" | {thing} appears to be a diffusers model.")
|
ialog.debug(f"{thing} appears to be a diffusers model.")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
ialog.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
|
||||||
)
|
|
||||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||||
Path(thing).rglob("*.safetensors")
|
Path(thing).rglob("*.safetensors")
|
||||||
):
|
):
|
||||||
if model_name := self.heuristic_import(
|
if model_name := self.heuristic_import(
|
||||||
str(m), commit_to_conf=commit_to_conf
|
str(m), commit_to_conf=commit_to_conf
|
||||||
):
|
):
|
||||||
print(f" >> {model_name} successfully imported")
|
ialog.info(f"{model_name} successfully imported")
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
ialog.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||||
return model_name
|
return model_name
|
||||||
else:
|
else:
|
||||||
print(
|
ialog.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||||
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Model_path is set in the event of a legacy checkpoint file.
|
# Model_path is set in the event of a legacy checkpoint file.
|
||||||
# If not set, we're all done
|
# If not set, we're all done
|
||||||
@ -845,7 +838,7 @@ class ModelManager(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if model_path.stem in self.config: # already imported
|
if model_path.stem in self.config: # already imported
|
||||||
print(" | Already imported. Skipping")
|
ialog.debug("Already imported. Skipping")
|
||||||
return model_path.stem
|
return model_path.stem
|
||||||
|
|
||||||
# another round of heuristics to guess the correct config file.
|
# another round of heuristics to guess the correct config file.
|
||||||
@ -861,39 +854,39 @@ class ModelManager(object):
|
|||||||
# look for a like-named .yaml file in same directory
|
# look for a like-named .yaml file in same directory
|
||||||
if model_path.with_suffix(".yaml").exists():
|
if model_path.with_suffix(".yaml").exists():
|
||||||
model_config_file = model_path.with_suffix(".yaml")
|
model_config_file = model_path.with_suffix(".yaml")
|
||||||
print(f" | Using config file {model_config_file.name}")
|
ialog.debug(f"Using config file {model_config_file.name}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_type = self.probe_model_type(checkpoint)
|
model_type = self.probe_model_type(checkpoint)
|
||||||
if model_type == SDLegacyType.V1:
|
if model_type == SDLegacyType.V1:
|
||||||
print(" | SD-v1 model detected")
|
ialog.debug("SD-v1 model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V1_INPAINT:
|
elif model_type == SDLegacyType.V1_INPAINT:
|
||||||
print(" | SD-v1 inpainting model detected")
|
ialog.debug("SD-v1 inpainting model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root,
|
Globals.root,
|
||||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_v:
|
elif model_type == SDLegacyType.V2_v:
|
||||||
print(" | SD-v2-v model detected")
|
ialog.debug("SD-v2-v model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_e:
|
elif model_type == SDLegacyType.V2_e:
|
||||||
print(" | SD-v2-e model detected")
|
ialog.debug("SD-v2-e model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
print(
|
ialog.warning(
|
||||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
print(
|
ialog.warning(
|
||||||
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -909,7 +902,7 @@ class ModelManager(object):
|
|||||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||||
print(f" | Using VAE file {vae_path.name}")
|
ialog.debug(f"Using VAE file {vae_path.name}")
|
||||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||||
|
|
||||||
diffuser_path = Path(
|
diffuser_path = Path(
|
||||||
@ -955,14 +948,14 @@ class ModelManager(object):
|
|||||||
from . import convert_ckpt_to_diffusers
|
from . import convert_ckpt_to_diffusers
|
||||||
|
|
||||||
if diffusers_path.exists():
|
if diffusers_path.exists():
|
||||||
print(
|
ialog.error(
|
||||||
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
model_name = model_name or diffusers_path.name
|
model_name = model_name or diffusers_path.name
|
||||||
model_description = model_description or f"Converted version of {model_name}"
|
model_description = model_description or f"Converted version of {model_name}"
|
||||||
print(f" | Converting {model_name} to diffusers (30-60s)")
|
ialog.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||||
try:
|
try:
|
||||||
# By passing the specified VAE to the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
@ -979,10 +972,10 @@ class ModelManager(object):
|
|||||||
vae_path=vae_path,
|
vae_path=vae_path,
|
||||||
scan_needed=scan_needed,
|
scan_needed=scan_needed,
|
||||||
)
|
)
|
||||||
print(
|
ialog.debug(
|
||||||
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||||
)
|
)
|
||||||
print(f" | Writing new config file entry for {model_name}")
|
ialog.debug(f"Writing new config file entry for {model_name}")
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
path=str(diffusers_path),
|
path=str(diffusers_path),
|
||||||
description=model_description,
|
description=model_description,
|
||||||
@ -993,17 +986,17 @@ class ModelManager(object):
|
|||||||
self.add_model(model_name, new_config, True)
|
self.add_model(model_name, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
print(" | Conversion succeeded")
|
ialog.debug("Conversion succeeded")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"** Conversion failed: {str(e)}")
|
ialog.warning(f"Conversion failed: {str(e)}")
|
||||||
print(
|
ialog.warning(
|
||||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
print(f">> Finding Models In: {search_folder}")
|
ialog.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||||
|
|
||||||
@ -1027,8 +1020,8 @@ class ModelManager(object):
|
|||||||
num_loaded_models = len(self.models)
|
num_loaded_models = len(self.models)
|
||||||
if num_loaded_models >= self.max_loaded_models:
|
if num_loaded_models >= self.max_loaded_models:
|
||||||
least_recent_model = self._pop_oldest_model()
|
least_recent_model = self._pop_oldest_model()
|
||||||
print(
|
ialog.info(
|
||||||
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||||
)
|
)
|
||||||
if least_recent_model is not None:
|
if least_recent_model is not None:
|
||||||
del self.models[least_recent_model]
|
del self.models[least_recent_model]
|
||||||
@ -1036,8 +1029,8 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def print_vram_usage(self) -> None:
|
def print_vram_usage(self) -> None:
|
||||||
if self._has_cuda:
|
if self._has_cuda:
|
||||||
print(
|
ialog.info(
|
||||||
">> Current VRAM usage: ",
|
"Current VRAM usage:"+
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1126,10 +1119,10 @@ class ModelManager(object):
|
|||||||
dest = hub / model.stem
|
dest = hub / model.stem
|
||||||
if dest.exists() and not source.exists():
|
if dest.exists() and not source.exists():
|
||||||
continue
|
continue
|
||||||
print(f"** {source} => {dest}")
|
ialog.info(f"{source} => {dest}")
|
||||||
if source.exists():
|
if source.exists():
|
||||||
if dest.is_symlink():
|
if dest.is_symlink():
|
||||||
print(f"** Found symlink at {dest.name}. Not migrating.")
|
ialog.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||||
elif dest.exists():
|
elif dest.exists():
|
||||||
if source.is_dir():
|
if source.is_dir():
|
||||||
rmtree(source)
|
rmtree(source)
|
||||||
@ -1146,7 +1139,7 @@ class ModelManager(object):
|
|||||||
]
|
]
|
||||||
for d in empty:
|
for d in empty:
|
||||||
os.rmdir(d)
|
os.rmdir(d)
|
||||||
print("** Migration is done. Continuing...")
|
ialog.info("Migration is done. Continuing...")
|
||||||
|
|
||||||
def _resolve_path(
|
def _resolve_path(
|
||||||
self, source: Union[str, Path], dest_directory: str
|
self, source: Union[str, Path], dest_directory: str
|
||||||
@ -1189,15 +1182,15 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
print(f">> Loading embeddings from {self.embedding_path}")
|
ialog.info(f"Loading embeddings from {self.embedding_path}")
|
||||||
for root, _, files in os.walk(self.embedding_path):
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
for name in files:
|
for name in files:
|
||||||
ti_path = os.path.join(root, name)
|
ti_path = os.path.join(root, name)
|
||||||
model.textual_inversion_manager.load_textual_inversion(
|
model.textual_inversion_manager.load_textual_inversion(
|
||||||
ti_path, defer_injecting_tokens=True
|
ti_path, defer_injecting_tokens=True
|
||||||
)
|
)
|
||||||
print(
|
ialog.info(
|
||||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||||
)
|
)
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
def _has_cuda(self) -> bool:
|
||||||
@ -1219,7 +1212,7 @@ class ModelManager(object):
|
|||||||
with open(hashpath) as f:
|
with open(hashpath) as f:
|
||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
print(" | Calculating sha256 hash of model files")
|
ialog.debug("Calculating sha256 hash of model files")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
count = 0
|
count = 0
|
||||||
@ -1231,7 +1224,7 @@ class ModelManager(object):
|
|||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
ialog.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
@ -1249,13 +1242,13 @@ class ModelManager(object):
|
|||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
print(" | Calculating sha256 hash of weights file")
|
ialog.debug("Calculating sha256 hash of weights file")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
sha.update(data)
|
sha.update(data)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
|
ialog.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||||
|
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
@ -1276,12 +1269,12 @@ class ModelManager(object):
|
|||||||
local_files_only=not Globals.internet_available,
|
local_files_only=not Globals.internet_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
ialog.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
vae_args.update(torch_dtype=torch.float16)
|
vae_args.update(torch_dtype=torch.float16)
|
||||||
fp_args_list = [{"revision": "fp16"}, {}]
|
fp_args_list = [{"revision": "fp16"}, {}]
|
||||||
else:
|
else:
|
||||||
print(" | Using more accurate float32 precision")
|
ialog.debug("Using more accurate float32 precision")
|
||||||
fp_args_list = [{}]
|
fp_args_list = [{}]
|
||||||
|
|
||||||
vae = None
|
vae = None
|
||||||
@ -1305,7 +1298,7 @@ class ModelManager(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not vae and deferred_error:
|
if not vae and deferred_error:
|
||||||
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
|
ialog.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||||
|
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
@ -1321,8 +1314,8 @@ class ModelManager(object):
|
|||||||
for revision in repo.revisions:
|
for revision in repo.revisions:
|
||||||
hashes_to_delete.add(revision.commit_hash)
|
hashes_to_delete.add(revision.commit_hash)
|
||||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||||
print(
|
ialog.warning(
|
||||||
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||||
)
|
)
|
||||||
strategy.execute()
|
strategy.execute()
|
||||||
|
|
||||||
|
107
invokeai/backend/util/logging.py
Normal file
107
invokeai/backend/util/logging.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""invokeai.util.logging
|
||||||
|
Copyright 2023 The InvokeAI Development Team
|
||||||
|
|
||||||
|
Logging class for InvokeAI that produces console messages that follow
|
||||||
|
the conventions established in InvokeAI 1.X through 2.X.
|
||||||
|
|
||||||
|
|
||||||
|
One way to use it:
|
||||||
|
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
logger = InvokeAILogger.getLogger(__name__)
|
||||||
|
logger.critical('this is critical')
|
||||||
|
logger.error('this is an error')
|
||||||
|
logger.warning('this is a warning')
|
||||||
|
logger.info('this is info')
|
||||||
|
logger.debug('this is debugging')
|
||||||
|
|
||||||
|
Console messages:
|
||||||
|
### this is critical
|
||||||
|
*** this is an error ***
|
||||||
|
** this is a warning
|
||||||
|
>> this is info
|
||||||
|
| this is debugging
|
||||||
|
|
||||||
|
Another way:
|
||||||
|
import invokeai.backend.util.logging as ialog
|
||||||
|
ialog.debug('this is a debugging message')
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def debug(msg:str):
|
||||||
|
InvokeAILogger.getLogger().debug(msg)
|
||||||
|
|
||||||
|
def info(msg:str):
|
||||||
|
InvokeAILogger.getLogger().info(msg)
|
||||||
|
|
||||||
|
def warning(msg:str):
|
||||||
|
InvokeAILogger.getLogger().warning(msg)
|
||||||
|
|
||||||
|
def error(msg:str):
|
||||||
|
InvokeAILogger.getLogger().error(msg)
|
||||||
|
|
||||||
|
def critical(msg:str):
|
||||||
|
InvokeAILogger.getLogger().critical(msg)
|
||||||
|
|
||||||
|
class InvokeAILogFormatter(logging.Formatter):
|
||||||
|
'''
|
||||||
|
Repurposed from:
|
||||||
|
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
|
||||||
|
'''
|
||||||
|
crit_fmt = "### %(msg)s"
|
||||||
|
err_fmt = "!!! %(msg)s !!!"
|
||||||
|
warn_fmt = "** %(msg)s"
|
||||||
|
info_fmt = ">> %(msg)s"
|
||||||
|
dbg_fmt = " | %(msg)s"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
# Remember the format used when the logging module
|
||||||
|
# was installed (in the event that this formatter is
|
||||||
|
# used with the vanilla logging module.
|
||||||
|
format_orig = self._style._fmt
|
||||||
|
if record.levelno == logging.DEBUG:
|
||||||
|
self._style._fmt = InvokeAILogFormatter.dbg_fmt
|
||||||
|
if record.levelno == logging.INFO:
|
||||||
|
self._style._fmt = InvokeAILogFormatter.info_fmt
|
||||||
|
if record.levelno == logging.WARNING:
|
||||||
|
self._style._fmt = InvokeAILogFormatter.warn_fmt
|
||||||
|
if record.levelno == logging.ERROR:
|
||||||
|
self._style._fmt = InvokeAILogFormatter.err_fmt
|
||||||
|
if record.levelno == logging.CRITICAL:
|
||||||
|
self._style._fmt = InvokeAILogFormatter.crit_fmt
|
||||||
|
|
||||||
|
# parent class does the work
|
||||||
|
result = super().format(record)
|
||||||
|
self._style._fmt = format_orig
|
||||||
|
return result
|
||||||
|
|
||||||
|
class InvokeAILogger(object):
|
||||||
|
loggers = dict()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def getLogger(self, name:str='invokeai')->logging.Logger:
|
||||||
|
if name not in self.loggers:
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
ch = logging.StreamHandler()
|
||||||
|
fmt = InvokeAILogFormatter()
|
||||||
|
ch.setFormatter(fmt)
|
||||||
|
logger.addHandler(ch)
|
||||||
|
self.loggers[name] = logger
|
||||||
|
return self.loggers[name]
|
||||||
|
|
||||||
|
def test():
|
||||||
|
logger = InvokeAILogger.getLogger('foobar')
|
||||||
|
logger.info('InvokeAI initialized')
|
||||||
|
logger.info('Running on GPU 14')
|
||||||
|
logger.info('Loading model foobar')
|
||||||
|
logger.debug('scanning for malware')
|
||||||
|
logger.debug('combobulating')
|
||||||
|
logger.warning('Oops. This model is strange.')
|
||||||
|
logger.error('Bailing out. sorry.')
|
||||||
|
logging.info('what happens when I log with logging?')
|
Loading…
x
Reference in New Issue
Block a user