mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
rename log to logger throughout
This commit is contained in:
parent
f0e07bff5a
commit
8db20e0d95
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -49,7 +49,7 @@ class ApiDependencies:
|
|||||||
Globals.disable_xformers = not config.xformers
|
Globals.disable_xformers = not config.xformers
|
||||||
Globals.ckpt_convert = config.ckpt_convert
|
Globals.ckpt_convert = config.ckpt_convert
|
||||||
|
|
||||||
log.info(f"Internet connectivity is {Globals.internet_available}")
|
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import shutil
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -116,16 +116,16 @@ async def delete_model(model_name: str) -> None:
|
|||||||
model_exists = model_name in model_names
|
model_exists = model_name in model_names
|
||||||
|
|
||||||
# check if model exists
|
# check if model exists
|
||||||
log.info(f"Checking for model {model_name}...")
|
logger.info(f"Checking for model {model_name}...")
|
||||||
|
|
||||||
if model_exists:
|
if model_exists:
|
||||||
log.info(f"Deleting Model: {model_name}")
|
logger.info(f"Deleting Model: {model_name}")
|
||||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||||
log.info(f"Model Deleted: {model_name}")
|
logger.info(f"Model Deleted: {model_name}")
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
log.error(f"Model not found")
|
logger.error(f"Model not found")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from ..invocations.image import ImageField
|
from ..invocations.image import ImageField
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||||
@ -230,7 +230,7 @@ class HistoryCommand(BaseCommand):
|
|||||||
for i in range(min(self.count, len(history))):
|
for i in range(min(self.count, len(history))):
|
||||||
entry_id = history[-1 - i]
|
entry_id = history[-1 - i]
|
||||||
entry = context.get_session().graph.get_node(entry_id)
|
entry = context.get_session().graph.get_node(entry_id)
|
||||||
log.info(f"{entry_id}: {get_invocation_command(entry)}")
|
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||||
|
|
||||||
|
|
||||||
class SetDefaultCommand(BaseCommand):
|
class SetDefaultCommand(BaseCommand):
|
||||||
|
@ -10,7 +10,7 @@ import shlex
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ...backend import ModelManager, Globals
|
from ...backend import ModelManager, Globals
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from .commands import BaseCommand
|
from .commands import BaseCommand
|
||||||
@ -161,7 +161,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
|||||||
pass
|
pass
|
||||||
except OSError: # file likely corrupted
|
except OSError: # file likely corrupted
|
||||||
newname = f"{histfile}.old"
|
newname = f"{histfile}.old"
|
||||||
log.error(
|
logger.error(
|
||||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||||
)
|
)
|
||||||
histfile.replace(Path(newname))
|
histfile.replace(Path(newname))
|
||||||
|
@ -14,7 +14,7 @@ from pydantic import BaseModel
|
|||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.metadata import PngMetadataService
|
from invokeai.app.services.metadata import PngMetadataService
|
||||||
from .services.default_graphs import create_system_graphs
|
from .services.default_graphs import create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -181,7 +181,7 @@ def invoke_all(context: CliContext):
|
|||||||
# Print any errors
|
# Print any errors
|
||||||
if context.session.has_error():
|
if context.session.has_error():
|
||||||
for n in context.session.errors:
|
for n in context.session.errors:
|
||||||
log.error(
|
logger.error(
|
||||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -364,12 +364,12 @@ def invoke_cli():
|
|||||||
invoke_all(context)
|
invoke_all(context)
|
||||||
|
|
||||||
except InvalidArgs:
|
except InvalidArgs:
|
||||||
log.warning('Invalid command, use "help" to list commands')
|
logger.warning('Invalid command, use "help" to list commands')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except SessionError:
|
except SessionError:
|
||||||
# Start a new session
|
# Start a new session
|
||||||
log.warning("Session error: creating a new session")
|
logger.warning("Session error: creating a new session")
|
||||||
context.reset()
|
context.reset()
|
||||||
|
|
||||||
except ExitCli:
|
except ExitCli:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.model_management.model_manager import ModelManager
|
from invokeai.backend.model_management.model_manager import ModelManager
|
||||||
|
|
||||||
|
|
||||||
@ -8,6 +8,6 @@ def choose_model(model_manager: ModelManager, model_name: str):
|
|||||||
model = model_manager.get_model(model_name)
|
model = model_manager.get_model(model_name)
|
||||||
else:
|
else:
|
||||||
model = model_manager.get_model()
|
model = model_manager.get_model()
|
||||||
log.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -7,7 +7,7 @@ from omegaconf import OmegaConf
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ...backend import ModelManager
|
from ...backend import ModelManager
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
@ -21,8 +21,8 @@ def get_model_manager(config: Args) -> ModelManager:
|
|||||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||||
log.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
# these two lines prevent a horrible warning message from appearing
|
||||||
# when the frozen CLIP tokenizer is imported
|
# when the frozen CLIP tokenizer is imported
|
||||||
@ -67,7 +67,7 @@ def get_model_manager(config: Args) -> ModelManager:
|
|||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(config, e)
|
report_model_error(config, e)
|
||||||
except (IOError, KeyError) as e:
|
except (IOError, KeyError) as e:
|
||||||
log.error(f"{e}. Aborting.")
|
logger.error(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# try to autoconvert new models
|
# try to autoconvert new models
|
||||||
@ -81,13 +81,13 @@ def get_model_manager(config: Args) -> ModelManager:
|
|||||||
return model_manager
|
return model_manager
|
||||||
|
|
||||||
def report_model_error(opt: Namespace, e: Exception):
|
def report_model_error(opt: Namespace, e: Exception):
|
||||||
log.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||||
log.error(
|
logger.error(
|
||||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||||
)
|
)
|
||||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||||
if yes_to_all:
|
if yes_to_all:
|
||||||
log.warning
|
logger.warning
|
||||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -97,7 +97,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
if response.startswith(("n", "N")):
|
if response.startswith(("n", "N")):
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info("invokeai-configure is launching....\n")
|
logger.info("invokeai-configure is launching....\n")
|
||||||
|
|
||||||
# Match arguments that were set on the CLI
|
# Match arguments that were set on the CLI
|
||||||
# only the arguments accepted by the configuration script are parsed
|
# only the arguments accepted by the configuration script are parsed
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ...backend.restoration import Restoration
|
from ...backend.restoration import Restoration
|
||||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||||
|
|
||||||
@ -21,16 +21,16 @@ class RestorationServices:
|
|||||||
args.gfpgan_model_path
|
args.gfpgan_model_path
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("Face restoration disabled")
|
logger.info("Face restoration disabled")
|
||||||
if args.esrgan:
|
if args.esrgan:
|
||||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||||
else:
|
else:
|
||||||
log.info("Upscaling disabled")
|
logger.info("Upscaling disabled")
|
||||||
else:
|
else:
|
||||||
log.info("Face restoration and upscaling disabled")
|
logger.info("Face restoration and upscaling disabled")
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
log.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
self.gfpgan = gfpgan
|
self.gfpgan = gfpgan
|
||||||
self.codeformer = codeformer
|
self.codeformer = codeformer
|
||||||
@ -59,14 +59,14 @@ class RestorationServices:
|
|||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == "gfpgan":
|
if facetool == "gfpgan":
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
log.info(
|
logger.info(
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
"GFPGAN not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
if facetool == "codeformer":
|
if facetool == "codeformer":
|
||||||
if self.codeformer is None:
|
if self.codeformer is None:
|
||||||
log.info(
|
logger.info(
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
"CodeFormer not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -81,7 +81,7 @@ class RestorationServices:
|
|||||||
fidelity=codeformer_fidelity,
|
fidelity=codeformer_fidelity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("Face Restoration is disabled.")
|
logger.info("Face Restoration is disabled.")
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if self.esrgan is not None:
|
if self.esrgan is not None:
|
||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
@ -94,9 +94,9 @@ class RestorationServices:
|
|||||||
denoise_str=upscale_denoise_str,
|
denoise_str=upscale_denoise_str,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("ESRGAN is disabled. Image not upscaled.")
|
logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.image_util import retrieve_metadata
|
from invokeai.backend.image_util import retrieve_metadata
|
||||||
|
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
@ -190,7 +190,7 @@ class Args(object):
|
|||||||
print(f"{APP_NAME} {APP_VERSION}")
|
print(f"{APP_NAME} {APP_VERSION}")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
log.info("Initializing, be patient...")
|
logger.info("Initializing, be patient...")
|
||||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||||
Globals.try_patchmatch = switches.patchmatch
|
Globals.try_patchmatch = switches.patchmatch
|
||||||
|
|
||||||
@ -198,12 +198,12 @@ class Args(object):
|
|||||||
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
||||||
legacyinit = os.path.expanduser("~/.invokeai")
|
legacyinit = os.path.expanduser("~/.invokeai")
|
||||||
if os.path.exists(initfile):
|
if os.path.exists(initfile):
|
||||||
log.info(
|
logger.info(
|
||||||
f"Initialization file {initfile} found. Loading...",
|
f"Initialization file {initfile} found. Loading...",
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{initfile}")
|
sysargs.insert(0, f"@{initfile}")
|
||||||
elif os.path.exists(legacyinit):
|
elif os.path.exists(legacyinit):
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{legacyinit}")
|
sysargs.insert(0, f"@{legacyinit}")
|
||||||
@ -214,7 +214,7 @@ class Args(object):
|
|||||||
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
||||||
return self._arg_switches
|
return self._arg_switches
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"An exception has occurred: {e}")
|
logger.error(f"An exception has occurred: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_cmd(self, cmd_string):
|
def parse_cmd(self, cmd_string):
|
||||||
@ -1154,7 +1154,7 @@ class Args(object):
|
|||||||
|
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
log.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||||
return metadata_dumps(kwargs)
|
return metadata_dumps(kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
log.error("Could not read metadata")
|
logger.error("Could not read metadata")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from .args import metadata_from_png
|
from .args import metadata_from_png
|
||||||
from .generator import infill_methods
|
from .generator import infill_methods
|
||||||
from .globals import Globals, global_cache_dir
|
from .globals import Globals, global_cache_dir
|
||||||
@ -196,12 +196,12 @@ class Generate:
|
|||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
# it wasn't actually doing anything. This logic could be reinstated.
|
# it wasn't actually doing anything. This logic could be reinstated.
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
log.info(f"Using device_type {self.device.type}")
|
logger.info(f"Using device_type {self.device.type}")
|
||||||
if full_precision:
|
if full_precision:
|
||||||
if self.precision != "auto":
|
if self.precision != "auto":
|
||||||
raise ValueError("Remove --full_precision / -F if using --precision")
|
raise ValueError("Remove --full_precision / -F if using --precision")
|
||||||
log.warning("Please remove deprecated --full_precision / -F")
|
logger.warning("Please remove deprecated --full_precision / -F")
|
||||||
log.warning("If auto config does not work you can use --precision=float32")
|
logger.warning("If auto config does not work you can use --precision=float32")
|
||||||
self.precision = "float32"
|
self.precision = "float32"
|
||||||
if self.precision == "auto":
|
if self.precision == "auto":
|
||||||
self.precision = choose_precision(self.device)
|
self.precision = choose_precision(self.device)
|
||||||
@ -209,13 +209,13 @@ class Generate:
|
|||||||
|
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
if torch.cuda.is_available() and not Globals.disable_xformers:
|
if torch.cuda.is_available() and not Globals.disable_xformers:
|
||||||
log.info("xformers memory-efficient attention is available and enabled")
|
logger.info("xformers memory-efficient attention is available and enabled")
|
||||||
else:
|
else:
|
||||||
log.info(
|
logger.info(
|
||||||
"xformers memory-efficient attention is available but disabled"
|
"xformers memory-efficient attention is available but disabled"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("xformers not installed")
|
logger.info("xformers not installed")
|
||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_manager = ModelManager(
|
self.model_manager = ModelManager(
|
||||||
@ -230,7 +230,7 @@ class Generate:
|
|||||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
model = model or fallback
|
model = model or fallback
|
||||||
if not self.model_manager.valid_model(model):
|
if not self.model_manager.valid_model(model):
|
||||||
log.warning(
|
logger.warning(
|
||||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||||
)
|
)
|
||||||
model = None
|
model = None
|
||||||
@ -247,10 +247,10 @@ class Generate:
|
|||||||
|
|
||||||
# load safety checker if requested
|
# load safety checker if requested
|
||||||
if safety_checker:
|
if safety_checker:
|
||||||
log.info("Initializing NSFW checker")
|
logger.info("Initializing NSFW checker")
|
||||||
self.safety_checker = SafetyChecker(self.device)
|
self.safety_checker = SafetyChecker(self.device)
|
||||||
else:
|
else:
|
||||||
log.info("NSFW checker is disabled")
|
logger.info("NSFW checker is disabled")
|
||||||
|
|
||||||
def prompt2png(self, prompt, outdir, **kwargs):
|
def prompt2png(self, prompt, outdir, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -568,7 +568,7 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
if catch_interrupts:
|
if catch_interrupts:
|
||||||
log.warning("Interrupted** Partial results will be returned.")
|
logger.warning("Interrupted** Partial results will be returned.")
|
||||||
else:
|
else:
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@ -576,11 +576,11 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
log.info("Could not generate image.")
|
logger.info("Could not generate image.")
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
log.info("Usage stats:")
|
logger.info("Usage stats:")
|
||||||
log.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
||||||
self.print_cuda_stats()
|
self.print_cuda_stats()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -610,14 +610,14 @@ class Generate:
|
|||||||
def print_cuda_stats(self):
|
def print_cuda_stats(self):
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
self.gather_cuda_stats()
|
self.gather_cuda_stats()
|
||||||
log.info(
|
logger.info(
|
||||||
"Max VRAM used for this generation: "+
|
"Max VRAM used for this generation: "+
|
||||||
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
||||||
"Current VRAM utilization: "+
|
"Current VRAM utilization: "+
|
||||||
"%4.2fG" % (self.memory_allocated / 1e9)
|
"%4.2fG" % (self.memory_allocated / 1e9)
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(
|
logger.info(
|
||||||
"Max VRAM used since script start: " +
|
"Max VRAM used since script start: " +
|
||||||
"%4.2fG" % (self.session_peakmem / 1e9)
|
"%4.2fG" % (self.session_peakmem / 1e9)
|
||||||
)
|
)
|
||||||
@ -648,7 +648,7 @@ class Generate:
|
|||||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
|
|
||||||
prompt = opt.prompt or args.prompt or ""
|
prompt = opt.prompt or args.prompt or ""
|
||||||
log.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||||
|
|
||||||
# try to reuse the same filename prefix as the original file.
|
# try to reuse the same filename prefix as the original file.
|
||||||
# we take everything up to the first period
|
# we take everything up to the first period
|
||||||
@ -697,7 +697,7 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
extend_instructions[direction] = int(pixels)
|
extend_instructions[direction] = int(pixels)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
log.warning(
|
logger.warning(
|
||||||
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -721,7 +721,7 @@ class Generate:
|
|||||||
# fetch the metadata from the image
|
# fetch the metadata from the image
|
||||||
generator = self.select_generator(embiggen=True)
|
generator = self.select_generator(embiggen=True)
|
||||||
opt.strength = opt.embiggen_strength or 0.40
|
opt.strength = opt.embiggen_strength or 0.40
|
||||||
log.info(
|
logger.info(
|
||||||
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
||||||
)
|
)
|
||||||
generator.generate(
|
generator.generate(
|
||||||
@ -749,12 +749,12 @@ class Generate:
|
|||||||
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
log.warning(
|
logger.warning(
|
||||||
"please provide at least one postprocessing option, such as -G or -U"
|
"please provide at least one postprocessing option, such as -G or -U"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
log.warning(f"postprocessing tool {tool} is not yet supported")
|
logger.warning(f"postprocessing tool {tool} is not yet supported")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def select_generator(
|
def select_generator(
|
||||||
@ -798,7 +798,7 @@ class Generate:
|
|||||||
image = self._load_img(img)
|
image = self._load_img(img)
|
||||||
|
|
||||||
if image.width < self.width and image.height < self.height:
|
if image.width < self.width and image.height < self.height:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -810,7 +810,7 @@ class Generate:
|
|||||||
if (image.width * image.height) > (
|
if (image.width * image.height) > (
|
||||||
self.width * self.height
|
self.width * self.height
|
||||||
) and self.size_matters:
|
) and self.size_matters:
|
||||||
log.info(
|
logger.info(
|
||||||
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||||
)
|
)
|
||||||
self.size_matters = False
|
self.size_matters = False
|
||||||
@ -892,11 +892,11 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
model_data = cache.get_model(model_name)
|
model_data = cache.get_model(model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"model {model_name} could not be loaded: {str(e)}")
|
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
if previous_model_name is None:
|
if previous_model_name is None:
|
||||||
raise e
|
raise e
|
||||||
log.warning("trying to reload previous model")
|
logger.warning("trying to reload previous model")
|
||||||
model_data = cache.get_model(previous_model_name) # load previous
|
model_data = cache.get_model(previous_model_name) # load previous
|
||||||
if model_data is None:
|
if model_data is None:
|
||||||
raise e
|
raise e
|
||||||
@ -963,14 +963,14 @@ class Generate:
|
|||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == "gfpgan":
|
if facetool == "gfpgan":
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
log.info(
|
logger.info(
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
"GFPGAN not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
if facetool == "codeformer":
|
if facetool == "codeformer":
|
||||||
if self.codeformer is None:
|
if self.codeformer is None:
|
||||||
log.info(
|
logger.info(
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
"CodeFormer not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -985,7 +985,7 @@ class Generate:
|
|||||||
fidelity=codeformer_fidelity,
|
fidelity=codeformer_fidelity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("Face Restoration is disabled.")
|
logger.info("Face Restoration is disabled.")
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if self.esrgan is not None:
|
if self.esrgan is not None:
|
||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
@ -998,9 +998,9 @@ class Generate:
|
|||||||
denoise_str=upscale_denoise_str,
|
denoise_str=upscale_denoise_str,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("ESRGAN is disabled. Image not upscaled.")
|
logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1077,7 +1077,7 @@ class Generate:
|
|||||||
)
|
)
|
||||||
self.sampler = default
|
self.sampler = default
|
||||||
|
|
||||||
log.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
if not hasattr(self.sampler, "uses_inpainting_model"):
|
if not hasattr(self.sampler, "uses_inpainting_model"):
|
||||||
# FIXME: terrible kludge!
|
# FIXME: terrible kludge!
|
||||||
@ -1086,17 +1086,17 @@ class Generate:
|
|||||||
def _load_img(self, img) -> Image:
|
def _load_img(self, img) -> Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
log.info(f"using provided input image of size {image.width}x{image.height}")
|
logger.info(f"using provided input image of size {image.width}x{image.height}")
|
||||||
elif isinstance(img, str):
|
elif isinstance(img, str):
|
||||||
assert os.path.exists(img), f"{img}: File not found"
|
assert os.path.exists(img), f"{img}: File not found"
|
||||||
|
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
log.info(
|
logger.info(
|
||||||
f"loaded input image of size {image.width}x{image.height} from {img}"
|
f"loaded input image of size {image.width}x{image.height} from {img}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
log.info(f"loaded input image of size {image.width}x{image.height}")
|
logger.info(f"loaded input image of size {image.width}x{image.height}")
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@ -1184,11 +1184,11 @@ class Generate:
|
|||||||
|
|
||||||
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
||||||
if not mask:
|
if not mask:
|
||||||
log.info(
|
logger.info(
|
||||||
"Initial image has transparent areas. Will inpaint in these regions."
|
"Initial image has transparent areas. Will inpaint in these regions."
|
||||||
)
|
)
|
||||||
if (not force_outpaint) and self._check_for_erasure(image):
|
if (not force_outpaint) and self._check_for_erasure(image):
|
||||||
log.info(
|
logger.info(
|
||||||
"Colors underneath the transparent region seem to have been erased.\n" +
|
"Colors underneath the transparent region seem to have been erased.\n" +
|
||||||
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
||||||
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
||||||
@ -1202,10 +1202,10 @@ class Generate:
|
|||||||
|
|
||||||
def _fit_image(self, image, max_dimensions):
|
def _fit_image(self, image, max_dimensions):
|
||||||
w, h = max_dimensions
|
w, h = max_dimensions
|
||||||
log.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
||||||
# note that InitImageResizer does the multiple of 64 truncation internally
|
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||||
image = InitImageResizer(image).resize(width=w, height=h)
|
image = InitImageResizer(image).resize(width=w, height=h)
|
||||||
log.info(
|
logger.info(
|
||||||
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||||
)
|
)
|
||||||
return image
|
return image
|
||||||
@ -1217,7 +1217,7 @@ class Generate:
|
|||||||
) # resize to integer multiple of 64
|
) # resize to integer multiple of 64
|
||||||
if h != height or w != width:
|
if h != height or w != width:
|
||||||
if log:
|
if log:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||||
)
|
)
|
||||||
height = h
|
height = h
|
||||||
|
@ -25,7 +25,7 @@ from typing import Callable, List, Iterator, Optional, Type
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ..image_util import configure_model_padding
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
@ -373,7 +373,7 @@ class Generator:
|
|||||||
try:
|
try:
|
||||||
x_T = self.get_noise(width, height)
|
x_T = self.get_noise(width, height)
|
||||||
except:
|
except:
|
||||||
log.error("An error occurred while getting initial noise")
|
logger.error("An error occurred while getting initial noise")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||||
@ -608,7 +608,7 @@ class Generator:
|
|||||||
image = self.sample_to_image(sample)
|
image = self.sample_to_image(sample)
|
||||||
dirname = os.path.dirname(filepath) or "."
|
dirname = os.path.dirname(filepath) or "."
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
log.info(f"creating directory {dirname}")
|
logger.info(f"creating directory {dirname}")
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
image.save(filepath, "PNG")
|
image.save(filepath, "PNG")
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from .base import Generator
|
from .base import Generator
|
||||||
from .img2img import Img2Img
|
from .img2img import Img2Img
|
||||||
@ -73,21 +73,21 @@ class Embiggen(Generator):
|
|||||||
embiggen = [1.0] # If not specified, assume no scaling
|
embiggen = [1.0] # If not specified, assume no scaling
|
||||||
elif embiggen[0] < 0:
|
elif embiggen[0] < 0:
|
||||||
embiggen[0] = 1.0
|
embiggen[0] = 1.0
|
||||||
log.warning(
|
logger.warning(
|
||||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 2:
|
if len(embiggen) < 2:
|
||||||
embiggen.append(0.75)
|
embiggen.append(0.75)
|
||||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||||
embiggen[1] = 0.75
|
embiggen[1] = 0.75
|
||||||
log.warning(
|
logger.warning(
|
||||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 3:
|
if len(embiggen) < 3:
|
||||||
embiggen.append(0.25)
|
embiggen.append(0.25)
|
||||||
elif embiggen[2] < 0:
|
elif embiggen[2] < 0:
|
||||||
embiggen[2] = 0.25
|
embiggen[2] = 0.25
|
||||||
log.warning(
|
logger.warning(
|
||||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ class Embiggen(Generator):
|
|||||||
embiggen_tiles.sort()
|
embiggen_tiles.sort()
|
||||||
|
|
||||||
if strength >= 0.5:
|
if strength >= 0.5:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ class Embiggen(Generator):
|
|||||||
from ..restoration.realesrgan import ESRGAN
|
from ..restoration.realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN()
|
esrgan = ESRGAN()
|
||||||
log.info(
|
logger.info(
|
||||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||||
)
|
)
|
||||||
if embiggen[0] > 2:
|
if embiggen[0] > 2:
|
||||||
@ -313,9 +313,9 @@ class Embiggen(Generator):
|
|||||||
def make_image():
|
def make_image():
|
||||||
# Make main tiles -------------------------------------------------
|
# Make main tiles -------------------------------------------------
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
log.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||||
else:
|
else:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -362,11 +362,11 @@ class Embiggen(Generator):
|
|||||||
# newinitimage.save(newinitimagepath)
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
# create a torch tensor from an Image
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||||
@ -548,7 +548,7 @@ class Embiggen(Generator):
|
|||||||
# Layer tile onto final image
|
# Layer tile onto final image
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
else:
|
else:
|
||||||
log.error(
|
logger.error(
|
||||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
|||||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
class Txt2Img2Img(Generator):
|
class Txt2Img2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -79,7 +79,7 @@ class Txt2Img2Img(Generator):
|
|||||||
# the message below is accurate.
|
# the message below is accurate.
|
||||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||||
log.info(
|
logger.info(
|
||||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ wraps the actual patchmatch object. It respects the global
|
|||||||
be suppressed or deferred
|
be suppressed or deferred
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
@ -27,12 +27,12 @@ class PatchMatch:
|
|||||||
from patchmatch import patch_match as pm
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
if pm.patchmatch_available:
|
||||||
log.info("Patchmatch initialized")
|
logger.info("Patchmatch initialized")
|
||||||
else:
|
else:
|
||||||
log.info("Patchmatch not loaded (nonfatal)")
|
logger.info("Patchmatch not loaded (nonfatal)")
|
||||||
self.patch_match = pm
|
self.patch_match = pm
|
||||||
else:
|
else:
|
||||||
log.info("Patchmatch loading disabled")
|
logger.info("Patchmatch loading disabled")
|
||||||
self.tried_load = True
|
self.tried_load = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import global_cache_dir
|
from invokeai.backend.globals import global_cache_dir
|
||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
@ -83,7 +83,7 @@ class Txt2Mask(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
log.info("Initializing clipseg model for text to mask inference")
|
logger.info("Initializing clipseg model for text to mask inference")
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -25,7 +25,7 @@ from typing import Union
|
|||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||||
|
|
||||||
from .model_manager import ModelManager, SDLegacyType
|
from .model_manager import ModelManager, SDLegacyType
|
||||||
@ -373,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
unet_key = "model.diffusion_model."
|
unet_key = "model.diffusion_model."
|
||||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
log.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||||
if extract_ema:
|
if extract_ema:
|
||||||
log.debug("Extracting EMA weights (usually better for inference)")
|
logger.debug("Extracting EMA weights (usually better for inference)")
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
@ -393,7 +393,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
key
|
key
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.debug(
|
logger.debug(
|
||||||
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1116,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
if "global_step" in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
global_step = checkpoint["global_step"]
|
global_step = checkpoint["global_step"]
|
||||||
else:
|
else:
|
||||||
log.debug("global_step key not found in model")
|
logger.debug("global_step key not found in model")
|
||||||
global_step = None
|
global_step = None
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
@ -1230,15 +1230,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
# If a replacement VAE path was specified, we'll incorporate that into
|
# If a replacement VAE path was specified, we'll incorporate that into
|
||||||
# the checkpoint model and then convert it
|
# the checkpoint model and then convert it
|
||||||
if vae_path:
|
if vae_path:
|
||||||
log.debug(f"Converting VAE {vae_path}")
|
logger.debug(f"Converting VAE {vae_path}")
|
||||||
replace_checkpoint_vae(checkpoint,vae_path)
|
replace_checkpoint_vae(checkpoint,vae_path)
|
||||||
# otherwise we use the original VAE, provided that
|
# otherwise we use the original VAE, provided that
|
||||||
# an externally loaded diffusers VAE was not passed
|
# an externally loaded diffusers VAE was not passed
|
||||||
elif not vae:
|
elif not vae:
|
||||||
log.debug("Using checkpoint model's original VAE")
|
logger.debug("Using checkpoint model's original VAE")
|
||||||
|
|
||||||
if vae:
|
if vae:
|
||||||
log.debug("Using replacement diffusers VAE")
|
logger.debug("Using replacement diffusers VAE")
|
||||||
else: # convert the original or replacement VAE
|
else: # convert the original or replacement VAE
|
||||||
vae_config = create_vae_diffusers_config(
|
vae_config = create_vae_diffusers_config(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
|
@ -24,7 +24,7 @@ import safetensors
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
@ -133,7 +133,7 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
if not self.valid_model(model_name):
|
||||||
log.error(
|
logger.error(
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return self.current_model
|
return self.current_model
|
||||||
@ -145,7 +145,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
if model_name in self.models:
|
if model_name in self.models:
|
||||||
requested_model = self.models[model_name]["model"]
|
requested_model = self.models[model_name]["model"]
|
||||||
log.info(f"Retrieving model {model_name} from system RAM cache")
|
logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||||
requested_model.ready()
|
requested_model.ready()
|
||||||
width = self.models[model_name]["width"]
|
width = self.models[model_name]["width"]
|
||||||
height = self.models[model_name]["height"]
|
height = self.models[model_name]["height"]
|
||||||
@ -380,7 +380,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
omega = self.config
|
omega = self.config
|
||||||
if model_name not in omega:
|
if model_name not in omega:
|
||||||
log.error(f"Unknown model {model_name}")
|
logger.error(f"Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
# save these for use in deletion later
|
# save these for use in deletion later
|
||||||
conf = omega[model_name]
|
conf = omega[model_name]
|
||||||
@ -393,13 +393,13 @@ class ModelManager(object):
|
|||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
if delete_files:
|
if delete_files:
|
||||||
if weights:
|
if weights:
|
||||||
log.info(f"Deleting file {weights}")
|
logger.info(f"Deleting file {weights}")
|
||||||
Path(weights).unlink(missing_ok=True)
|
Path(weights).unlink(missing_ok=True)
|
||||||
elif path:
|
elif path:
|
||||||
log.info(f"Deleting directory {path}")
|
logger.info(f"Deleting directory {path}")
|
||||||
rmtree(path, ignore_errors=True)
|
rmtree(path, ignore_errors=True)
|
||||||
elif repo_id:
|
elif repo_id:
|
||||||
log.info(f"Deleting the cached model directory for {repo_id}")
|
logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||||
self._delete_model_from_cache(repo_id)
|
self._delete_model_from_cache(repo_id)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
@ -440,7 +440,7 @@ class ModelManager(object):
|
|||||||
def _load_model(self, model_name: str):
|
def _load_model(self, model_name: str):
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||||
if model_name not in self.config:
|
if model_name not in self.config:
|
||||||
log.error(
|
logger.error(
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -458,7 +458,7 @@ class ModelManager(object):
|
|||||||
model_format = mconfig.get("format", "ckpt")
|
model_format = mconfig.get("format", "ckpt")
|
||||||
if model_format == "ckpt":
|
if model_format == "ckpt":
|
||||||
weights = mconfig.weights
|
weights = mconfig.weights
|
||||||
log.info(f"Loading {model_name} from {weights}")
|
logger.info(f"Loading {model_name} from {weights}")
|
||||||
model, width, height, model_hash = self._load_ckpt_model(
|
model, width, height, model_hash = self._load_ckpt_model(
|
||||||
model_name, mconfig
|
model_name, mconfig
|
||||||
)
|
)
|
||||||
@ -474,13 +474,13 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# usage statistics
|
# usage statistics
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
log.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
log.info(
|
logger.info(
|
||||||
"Max VRAM used to load the model: "+
|
"Max VRAM used to load the model: "+
|
||||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||||
)
|
)
|
||||||
log.info(
|
logger.info(
|
||||||
"Current VRAM usage: "+
|
"Current VRAM usage: "+
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||||
)
|
)
|
||||||
@ -490,11 +490,11 @@ class ModelManager(object):
|
|||||||
name_or_path = self.model_name_or_path(mconfig)
|
name_or_path = self.model_name_or_path(mconfig)
|
||||||
using_fp16 = self.precision == "float16"
|
using_fp16 = self.precision == "float16"
|
||||||
|
|
||||||
log.info(f"Loading diffusers model from {name_or_path}")
|
logger.info(f"Loading diffusers model from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
log.debug("Using faster float16 precision")
|
logger.debug("Using faster float16 precision")
|
||||||
else:
|
else:
|
||||||
log.debug("Using more accurate float32 precision")
|
logger.debug("Using more accurate float32 precision")
|
||||||
|
|
||||||
# TODO: scan weights maybe?
|
# TODO: scan weights maybe?
|
||||||
pipeline_args: dict[str, Any] = dict(
|
pipeline_args: dict[str, Any] = dict(
|
||||||
@ -526,7 +526,7 @@ class ModelManager(object):
|
|||||||
if str(e).startswith("fp16 is not a valid"):
|
if str(e).startswith("fp16 is not a valid"):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
log.error(
|
logger.error(
|
||||||
f"An unexpected error occurred while downloading the model: {e})"
|
f"An unexpected error occurred while downloading the model: {e})"
|
||||||
)
|
)
|
||||||
if pipeline:
|
if pipeline:
|
||||||
@ -545,7 +545,7 @@ class ModelManager(object):
|
|||||||
# square images???
|
# square images???
|
||||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||||
height = width
|
height = width
|
||||||
log.debug(f"Default image dimensions = {width} x {height}")
|
logger.debug(f"Default image dimensions = {width} x {height}")
|
||||||
|
|
||||||
return pipeline, width, height, model_hash
|
return pipeline, width, height, model_hash
|
||||||
|
|
||||||
@ -562,7 +562,7 @@ class ModelManager(object):
|
|||||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||||
|
|
||||||
# Convert to diffusers and return a diffusers pipeline
|
# Convert to diffusers and return a diffusers pipeline
|
||||||
log.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||||
|
|
||||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
@ -627,7 +627,7 @@ class ModelManager(object):
|
|||||||
if model_name not in self.models:
|
if model_name not in self.models:
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info(f"Offloading {model_name} to CPU")
|
logger.info(f"Offloading {model_name} to CPU")
|
||||||
model = self.models[model_name]["model"]
|
model = self.models[model_name]["model"]
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
self.current_model = None
|
self.current_model = None
|
||||||
@ -643,26 +643,26 @@ class ModelManager(object):
|
|||||||
and option to exit if an infected file is identified.
|
and option to exit if an infected file is identified.
|
||||||
"""
|
"""
|
||||||
# scan model
|
# scan model
|
||||||
log.debug(f"Scanning Model: {model_name}")
|
logger.debug(f"Scanning Model: {model_name}")
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
log.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||||
log.critical("The model you are trying to load seems to be infected.")
|
logger.critical("The model you are trying to load seems to be infected.")
|
||||||
log.critical("For your safety, InvokeAI will not load this model.")
|
logger.critical("For your safety, InvokeAI will not load this model.")
|
||||||
log.critical("Please use checkpoints from trusted sources.")
|
logger.critical("Please use checkpoints from trusted sources.")
|
||||||
log.critical("Exiting InvokeAI")
|
logger.critical("Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
log.warning("InvokeAI was unable to scan the model you are using.")
|
logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||||
model_safe_check_fail = ask_user(
|
model_safe_check_fail = ask_user(
|
||||||
"Do you want to to continue loading the model?", ["y", "n"]
|
"Do you want to to continue loading the model?", ["y", "n"]
|
||||||
)
|
)
|
||||||
if model_safe_check_fail.lower() != "y":
|
if model_safe_check_fail.lower() != "y":
|
||||||
log.critical("Exiting InvokeAI")
|
logger.critical("Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
log.debug("Model scanned ok")
|
logger.debug("Model scanned ok")
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@ -779,24 +779,24 @@ class ModelManager(object):
|
|||||||
model_path: Path = None
|
model_path: Path = None
|
||||||
thing = path_url_or_repo # to save typing
|
thing = path_url_or_repo # to save typing
|
||||||
|
|
||||||
log.info(f"Probing {thing} for import")
|
logger.info(f"Probing {thing} for import")
|
||||||
|
|
||||||
if thing.startswith(("http:", "https:", "ftp:")):
|
if thing.startswith(("http:", "https:", "ftp:")):
|
||||||
log.info(f"{thing} appears to be a URL")
|
logger.info(f"{thing} appears to be a URL")
|
||||||
model_path = self._resolve_path(
|
model_path = self._resolve_path(
|
||||||
thing, "models/ldm/stable-diffusion-v1"
|
thing, "models/ldm/stable-diffusion-v1"
|
||||||
) # _resolve_path does a download if needed
|
) # _resolve_path does a download if needed
|
||||||
|
|
||||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||||
log.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
log.debug(f"{thing} appears to be a checkpoint file on disk")
|
logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||||
|
|
||||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||||
log.debug(f"{thing} appears to be a diffusers file on disk")
|
logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing,
|
thing,
|
||||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||||
@ -807,30 +807,30 @@ class ModelManager(object):
|
|||||||
|
|
||||||
elif Path(thing).is_dir():
|
elif Path(thing).is_dir():
|
||||||
if (Path(thing) / "model_index.json").exists():
|
if (Path(thing) / "model_index.json").exists():
|
||||||
log.debug(f"{thing} appears to be a diffusers model.")
|
logger.debug(f"{thing} appears to be a diffusers model.")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||||
Path(thing).rglob("*.safetensors")
|
Path(thing).rglob("*.safetensors")
|
||||||
):
|
):
|
||||||
if model_name := self.heuristic_import(
|
if model_name := self.heuristic_import(
|
||||||
str(m), commit_to_conf=commit_to_conf
|
str(m), commit_to_conf=commit_to_conf
|
||||||
):
|
):
|
||||||
log.info(f"{model_name} successfully imported")
|
logger.info(f"{model_name} successfully imported")
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||||
log.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||||
return model_name
|
return model_name
|
||||||
else:
|
else:
|
||||||
log.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||||
|
|
||||||
# Model_path is set in the event of a legacy checkpoint file.
|
# Model_path is set in the event of a legacy checkpoint file.
|
||||||
# If not set, we're all done
|
# If not set, we're all done
|
||||||
@ -838,7 +838,7 @@ class ModelManager(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if model_path.stem in self.config: # already imported
|
if model_path.stem in self.config: # already imported
|
||||||
log.debug("Already imported. Skipping")
|
logger.debug("Already imported. Skipping")
|
||||||
return model_path.stem
|
return model_path.stem
|
||||||
|
|
||||||
# another round of heuristics to guess the correct config file.
|
# another round of heuristics to guess the correct config file.
|
||||||
@ -854,38 +854,38 @@ class ModelManager(object):
|
|||||||
# look for a like-named .yaml file in same directory
|
# look for a like-named .yaml file in same directory
|
||||||
if model_path.with_suffix(".yaml").exists():
|
if model_path.with_suffix(".yaml").exists():
|
||||||
model_config_file = model_path.with_suffix(".yaml")
|
model_config_file = model_path.with_suffix(".yaml")
|
||||||
log.debug(f"Using config file {model_config_file.name}")
|
logger.debug(f"Using config file {model_config_file.name}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_type = self.probe_model_type(checkpoint)
|
model_type = self.probe_model_type(checkpoint)
|
||||||
if model_type == SDLegacyType.V1:
|
if model_type == SDLegacyType.V1:
|
||||||
log.debug("SD-v1 model detected")
|
logger.debug("SD-v1 model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V1_INPAINT:
|
elif model_type == SDLegacyType.V1_INPAINT:
|
||||||
log.debug("SD-v1 inpainting model detected")
|
logger.debug("SD-v1 inpainting model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root,
|
Globals.root,
|
||||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_v:
|
elif model_type == SDLegacyType.V2_v:
|
||||||
log.debug("SD-v2-v model detected")
|
logger.debug("SD-v2-v model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_e:
|
elif model_type == SDLegacyType.V2_e:
|
||||||
log.debug("SD-v2-e model detected")
|
logger.debug("SD-v2-e model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -902,7 +902,7 @@ class ModelManager(object):
|
|||||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||||
log.debug(f"Using VAE file {vae_path.name}")
|
logger.debug(f"Using VAE file {vae_path.name}")
|
||||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||||
|
|
||||||
diffuser_path = Path(
|
diffuser_path = Path(
|
||||||
@ -948,14 +948,14 @@ class ModelManager(object):
|
|||||||
from . import convert_ckpt_to_diffusers
|
from . import convert_ckpt_to_diffusers
|
||||||
|
|
||||||
if diffusers_path.exists():
|
if diffusers_path.exists():
|
||||||
log.error(
|
logger.error(
|
||||||
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
model_name = model_name or diffusers_path.name
|
model_name = model_name or diffusers_path.name
|
||||||
model_description = model_description or f"Converted version of {model_name}"
|
model_description = model_description or f"Converted version of {model_name}"
|
||||||
log.debug(f"Converting {model_name} to diffusers (30-60s)")
|
logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||||
try:
|
try:
|
||||||
# By passing the specified VAE to the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
@ -972,10 +972,10 @@ class ModelManager(object):
|
|||||||
vae_path=vae_path,
|
vae_path=vae_path,
|
||||||
scan_needed=scan_needed,
|
scan_needed=scan_needed,
|
||||||
)
|
)
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||||
)
|
)
|
||||||
log.debug(f"Writing new config file entry for {model_name}")
|
logger.debug(f"Writing new config file entry for {model_name}")
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
path=str(diffusers_path),
|
path=str(diffusers_path),
|
||||||
description=model_description,
|
description=model_description,
|
||||||
@ -986,17 +986,17 @@ class ModelManager(object):
|
|||||||
self.add_model(model_name, new_config, True)
|
self.add_model(model_name, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
log.debug("Conversion succeeded")
|
logger.debug("Conversion succeeded")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"Conversion failed: {str(e)}")
|
logger.warning(f"Conversion failed: {str(e)}")
|
||||||
log.warning(
|
logger.warning(
|
||||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
log.info(f"Finding Models In: {search_folder}")
|
logger.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||||
|
|
||||||
@ -1020,7 +1020,7 @@ class ModelManager(object):
|
|||||||
num_loaded_models = len(self.models)
|
num_loaded_models = len(self.models)
|
||||||
if num_loaded_models >= self.max_loaded_models:
|
if num_loaded_models >= self.max_loaded_models:
|
||||||
least_recent_model = self._pop_oldest_model()
|
least_recent_model = self._pop_oldest_model()
|
||||||
log.info(
|
logger.info(
|
||||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||||
)
|
)
|
||||||
if least_recent_model is not None:
|
if least_recent_model is not None:
|
||||||
@ -1029,7 +1029,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def print_vram_usage(self) -> None:
|
def print_vram_usage(self) -> None:
|
||||||
if self._has_cuda:
|
if self._has_cuda:
|
||||||
log.info(
|
logger.info(
|
||||||
"Current VRAM usage:"+
|
"Current VRAM usage:"+
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||||
)
|
)
|
||||||
@ -1119,10 +1119,10 @@ class ModelManager(object):
|
|||||||
dest = hub / model.stem
|
dest = hub / model.stem
|
||||||
if dest.exists() and not source.exists():
|
if dest.exists() and not source.exists():
|
||||||
continue
|
continue
|
||||||
log.info(f"{source} => {dest}")
|
logger.info(f"{source} => {dest}")
|
||||||
if source.exists():
|
if source.exists():
|
||||||
if dest.is_symlink():
|
if dest.is_symlink():
|
||||||
log.warning(f"Found symlink at {dest.name}. Not migrating.")
|
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||||
elif dest.exists():
|
elif dest.exists():
|
||||||
if source.is_dir():
|
if source.is_dir():
|
||||||
rmtree(source)
|
rmtree(source)
|
||||||
@ -1139,7 +1139,7 @@ class ModelManager(object):
|
|||||||
]
|
]
|
||||||
for d in empty:
|
for d in empty:
|
||||||
os.rmdir(d)
|
os.rmdir(d)
|
||||||
log.info("Migration is done. Continuing...")
|
logger.info("Migration is done. Continuing...")
|
||||||
|
|
||||||
def _resolve_path(
|
def _resolve_path(
|
||||||
self, source: Union[str, Path], dest_directory: str
|
self, source: Union[str, Path], dest_directory: str
|
||||||
@ -1182,14 +1182,14 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
log.info(f"Loading embeddings from {self.embedding_path}")
|
logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||||
for root, _, files in os.walk(self.embedding_path):
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
for name in files:
|
for name in files:
|
||||||
ti_path = os.path.join(root, name)
|
ti_path = os.path.join(root, name)
|
||||||
model.textual_inversion_manager.load_textual_inversion(
|
model.textual_inversion_manager.load_textual_inversion(
|
||||||
ti_path, defer_injecting_tokens=True
|
ti_path, defer_injecting_tokens=True
|
||||||
)
|
)
|
||||||
log.info(
|
logger.info(
|
||||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1212,7 +1212,7 @@ class ModelManager(object):
|
|||||||
with open(hashpath) as f:
|
with open(hashpath) as f:
|
||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
log.debug("Calculating sha256 hash of model files")
|
logger.debug("Calculating sha256 hash of model files")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
count = 0
|
count = 0
|
||||||
@ -1224,7 +1224,7 @@ class ModelManager(object):
|
|||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
log.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
@ -1242,13 +1242,13 @@ class ModelManager(object):
|
|||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
log.debug("Calculating sha256 hash of weights file")
|
logger.debug("Calculating sha256 hash of weights file")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
sha.update(data)
|
sha.update(data)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
log.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||||
|
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
@ -1269,12 +1269,12 @@ class ModelManager(object):
|
|||||||
local_files_only=not Globals.internet_available,
|
local_files_only=not Globals.internet_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug(f"Loading diffusers VAE from {name_or_path}")
|
logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
vae_args.update(torch_dtype=torch.float16)
|
vae_args.update(torch_dtype=torch.float16)
|
||||||
fp_args_list = [{"revision": "fp16"}, {}]
|
fp_args_list = [{"revision": "fp16"}, {}]
|
||||||
else:
|
else:
|
||||||
log.debug("Using more accurate float32 precision")
|
logger.debug("Using more accurate float32 precision")
|
||||||
fp_args_list = [{}]
|
fp_args_list = [{}]
|
||||||
|
|
||||||
vae = None
|
vae = None
|
||||||
@ -1298,7 +1298,7 @@ class ModelManager(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not vae and deferred_error:
|
if not vae and deferred_error:
|
||||||
log.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||||
|
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
@ -1314,7 +1314,7 @@ class ModelManager(object):
|
|||||||
for revision in repo.revisions:
|
for revision in repo.revisions:
|
||||||
hashes_to_delete.add(revision.commit_hash)
|
hashes_to_delete.add(revision.commit_hash)
|
||||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||||
)
|
)
|
||||||
strategy.execute()
|
strategy.execute()
|
||||||
|
@ -18,7 +18,7 @@ from compel.prompt_parser import (
|
|||||||
PromptParser,
|
PromptParser,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
@ -163,8 +163,8 @@ def log_tokenization(
|
|||||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
):
|
):
|
||||||
log.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||||
log.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||||
|
|
||||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||||
log_tokenization_for_prompt_object(
|
log_tokenization_for_prompt_object(
|
||||||
@ -238,12 +238,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
|||||||
usedTokens += 1
|
usedTokens += 1
|
||||||
|
|
||||||
if usedTokens > 0:
|
if usedTokens > 0:
|
||||||
log.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||||
log.debug(f"{tokenized}\x1b[0m")
|
logger.debug(f"{tokenized}\x1b[0m")
|
||||||
|
|
||||||
if discarded != "":
|
if discarded != "":
|
||||||
log.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||||
log.debug(f"{discarded}\x1b[0m")
|
logger.debug(f"{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
|
||||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||||
@ -296,7 +296,7 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
|||||||
return parsed_prompts
|
return parsed_prompts
|
||||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
if weight_sum == 0:
|
if weight_sum == 0:
|
||||||
log.warning(
|
logger.warning(
|
||||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||||
)
|
)
|
||||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
class Restoration:
|
class Restoration:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -10,17 +10,17 @@ class Restoration:
|
|||||||
# Load GFPGAN
|
# Load GFPGAN
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||||
if gfpgan.gfpgan_model_exists:
|
if gfpgan.gfpgan_model_exists:
|
||||||
log.info("GFPGAN Initialized")
|
logger.info("GFPGAN Initialized")
|
||||||
else:
|
else:
|
||||||
log.info("GFPGAN Disabled")
|
logger.info("GFPGAN Disabled")
|
||||||
gfpgan = None
|
gfpgan = None
|
||||||
|
|
||||||
# Load CodeFormer
|
# Load CodeFormer
|
||||||
codeformer = self.load_codeformer()
|
codeformer = self.load_codeformer()
|
||||||
if codeformer.codeformer_model_exists:
|
if codeformer.codeformer_model_exists:
|
||||||
log.info("CodeFormer Initialized")
|
logger.info("CodeFormer Initialized")
|
||||||
else:
|
else:
|
||||||
log.info("CodeFormer Disabled")
|
logger.info("CodeFormer Disabled")
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
return gfpgan, codeformer
|
return gfpgan, codeformer
|
||||||
@ -41,5 +41,5 @@ class Restoration:
|
|||||||
from .realesrgan import ESRGAN
|
from .realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
esrgan = ESRGAN(esrgan_bg_tile)
|
||||||
log.info("ESRGAN Initialized")
|
logger.info("ESRGAN Initialized")
|
||||||
return esrgan
|
return esrgan
|
||||||
|
@ -5,7 +5,7 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ..globals import Globals
|
from ..globals import Globals
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
@ -24,12 +24,12 @@ class CodeFormerRestoration:
|
|||||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
if not self.codeformer_model_exists:
|
||||||
log.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
sys.path.append(os.path.abspath(codeformer_dir))
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
log.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
@ -98,7 +98,7 @@ class CodeFormerRestoration:
|
|||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
log.error(f"Failed inference for CodeFormer: {error}.")
|
logger.error(f"Failed inference for CodeFormer: {error}.")
|
||||||
restored_face = cropped_face
|
restored_face = cropped_face
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
restored_face = restored_face.astype("uint8")
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
@ -19,7 +19,7 @@ class GFPGAN:
|
|||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
if not self.gfpgan_model_exists:
|
||||||
log.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_exists(self):
|
def model_exists(self):
|
||||||
@ -27,7 +27,7 @@ class GFPGAN:
|
|||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
def process(self, image, strength: float, seed: str = None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
log.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
@ -47,13 +47,13 @@ class GFPGAN:
|
|||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
log.error("Error loading GFPGAN:", file=sys.stderr)
|
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
|
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
log.warning("WARNING: GFPGAN not initialized.")
|
logger.warning("WARNING: GFPGAN not initialized.")
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
class Outcrop(object):
|
class Outcrop(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -82,7 +82,7 @@ class Outcrop(object):
|
|||||||
pixels = extents[direction]
|
pixels = extents[direction]
|
||||||
# round pixels up to the nearest 64
|
# round pixels up to the nearest 64
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
pixels = math.ceil(pixels / 64) * 64
|
||||||
log.info(f"extending image {direction}ward by {pixels} pixels")
|
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
||||||
image = self._rotate(image, direction)
|
image = self._rotate(image, direction)
|
||||||
image = self._extend(image, pixels)
|
image = self._extend(image, pixels)
|
||||||
image = self._rotate(image, direction, reverse=True)
|
image = self._rotate(image, direction, reverse=True)
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
@ -69,15 +69,15 @@ class ESRGAN:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
log.error("Error loading Real-ESRGAN:")
|
logger.error("Error loading Real-ESRGAN:")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
if upsampler_scale == 0:
|
||||||
log.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||||
)
|
)
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||||
|
@ -14,7 +14,7 @@ from PIL import Image, ImageFilter
|
|||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from .globals import global_cache_dir
|
from .globals import global_cache_dir
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ class SafetyChecker(object):
|
|||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.error(
|
logger.error(
|
||||||
"An error was encountered while installing the safety checker:"
|
"An error was encountered while installing the safety checker:"
|
||||||
)
|
)
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
@ -66,7 +66,7 @@ class SafetyChecker(object):
|
|||||||
)
|
)
|
||||||
self.safety_checker.to(CPU_DEVICE) # offload
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
if has_nsfw_concept[0]:
|
if has_nsfw_concept[0]:
|
||||||
log.warning(
|
logger.warning(
|
||||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
||||||
)
|
)
|
||||||
return self.blur(image)
|
return self.blur(image)
|
||||||
|
@ -17,7 +17,7 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
@ -67,10 +67,10 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||||
)
|
)
|
||||||
log.warning(
|
logger.warning(
|
||||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||||
)
|
)
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
@ -84,7 +84,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
"""
|
"""
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -222,7 +222,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
if chunk == 0:
|
if chunk == 0:
|
||||||
bytes += total
|
bytes += total
|
||||||
|
|
||||||
log.info(f"Downloading {repo_id}...", end="")
|
logger.info(f"Downloading {repo_id}...", end="")
|
||||||
try:
|
try:
|
||||||
for file in (
|
for file in (
|
||||||
"README.md",
|
"README.md",
|
||||||
@ -236,22 +236,22 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
)
|
)
|
||||||
except ul_error.HTTPError as e:
|
except ul_error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
except ul_error.URLError as e:
|
except ul_error.URLError as e:
|
||||||
log.error(
|
logger.error(
|
||||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
log.info("...{:.2f}Kb".format(bytes / 1024))
|
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||||
return succeeded
|
return succeeded
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
def _concept_id(self, concept_name: str) -> str:
|
||||||
|
@ -13,7 +13,7 @@ from compel.cross_attention_control import Arguments
|
|||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
@ -421,7 +421,7 @@ def get_cross_attention_modules(
|
|||||||
expected_count = 16
|
expected_count = 16
|
||||||
if cross_attention_modules_in_model_count != expected_count:
|
if cross_attention_modules_in_model_count != expected_count:
|
||||||
# non-fatal error but .swap() won't work.
|
# non-fatal error but .swap() won't work.
|
||||||
log.error(
|
logger.error(
|
||||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
@ -467,13 +467,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
outside = torch.count_nonzero(
|
outside = torch.count_nonzero(
|
||||||
(latents < -current_threshold) | (latents > current_threshold)
|
(latents < -current_threshold) | (latents > current_threshold)
|
||||||
)
|
)
|
||||||
log.info(
|
logger.info(
|
||||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||||
)
|
)
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||||
)
|
)
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -501,10 +501,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_thresholding:
|
if self.debug_thresholding:
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||||
)
|
)
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
|||||||
|
|
||||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
|
||||||
|
|
||||||
@ -191,7 +191,7 @@ def mkdirs(paths):
|
|||||||
def mkdir_and_rename(path):
|
def mkdir_and_rename(path):
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
new_name = path + "_archived_" + get_timestamp()
|
new_name = path + "_archived_" + get_timestamp()
|
||||||
log.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||||
os.replace(path, new_name)
|
os.replace(path, new_name)
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -60,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||||
): # in case a token with literal angle brackets encountered
|
): # in case a token with literal angle brackets encountered
|
||||||
log.info(f"Loaded local embedding for trigger {concept_name}")
|
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
log.info(f"Loaded remote embedding for trigger {concept_name}")
|
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||||
for embedding_info in embedding_list:
|
for embedding_info in embedding_list:
|
||||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@ -106,7 +106,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if ckpt_path.name == "learned_embeds.bin"
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
else f"<{ckpt_path.stem}>"
|
else f"<{ckpt_path.stem}>"
|
||||||
)
|
)
|
||||||
log.info(
|
logger.info(
|
||||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||||
)
|
)
|
||||||
trigger_str = replacement_trigger_str
|
trigger_str = replacement_trigger_str
|
||||||
@ -121,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
log.debug(f"The error was {str(e)}")
|
logger.debug(f"The error was {str(e)}")
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||||
@ -134,7 +134,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
:return: The token id for the added embedding, either existing or newly-added.
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
"""
|
"""
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
log.warning(
|
logger.warning(
|
||||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -156,10 +156,10 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if str(e).startswith("Warning"):
|
if str(e).startswith("Warning"):
|
||||||
log.warning(f"{str(e)}")
|
logger.warning(f"{str(e)}")
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
log.error(
|
logger.error(
|
||||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
@ -220,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
for ti in self.textual_inversions:
|
for ti in self.textual_inversions:
|
||||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||||
if ti.embedding_vector_length > 1:
|
if ti.embedding_vector_length > 1:
|
||||||
log.info(
|
logger.info(
|
||||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||||
)
|
)
|
||||||
log.debug(f"The error was {str(e)}")
|
logger.debug(f"The error was {str(e)}")
|
||||||
continue
|
continue
|
||||||
injected_token_ids.append(ti.trigger_token_id)
|
injected_token_ids.append(ti.trigger_token_id)
|
||||||
injected_token_ids.extend(ti.pad_token_ids)
|
injected_token_ids.extend(ti.pad_token_ids)
|
||||||
@ -307,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if suffix in [".pt",".ckpt",".bin"]:
|
if suffix in [".pt",".ckpt",".bin"]:
|
||||||
scan_result = scan_file_path(embedding_file)
|
scan_result = scan_file_path(embedding_file)
|
||||||
if scan_result.infected_files > 0:
|
if scan_result.infected_files > 0:
|
||||||
log.critical(
|
logger.critical(
|
||||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
)
|
)
|
||||||
log.critical("For your safety, InvokeAI will not load this embed.")
|
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||||
return list()
|
return list()
|
||||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||||
else:
|
else:
|
||||||
ckpt = safetensors.torch.load_file(embedding_file)
|
ckpt = safetensors.torch.load_file(embedding_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||||
return list()
|
return list()
|
||||||
|
|
||||||
# try to figure out what kind of embedding file it is and parse accordingly
|
# try to figure out what kind of embedding file it is and parse accordingly
|
||||||
@ -335,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
log.debug(f'Loading v1 embedding file: {basename}')
|
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
token_counter = -1
|
token_counter = -1
|
||||||
@ -366,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
This handles embedding .pt file variant #2.
|
This handles embedding .pt file variant #2.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
log.debug(f'Loading v2 embedding file: {basename}')
|
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
@ -385,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
)
|
)
|
||||||
embeddings.append(embedding_info)
|
embeddings.append(embedding_info)
|
||||||
else:
|
else:
|
||||||
log.warning(f"{basename}: Unrecognized embedding format")
|
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -394,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
log.debug(f'Loading v3 embedding file: {basename}')
|
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||||
embedding = embedding_ckpt['emb_params']
|
embedding = embedding_ckpt['emb_params']
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
name = f'<{basename}>',
|
name = f'<{basename}>',
|
||||||
@ -412,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
basename = Path(filepath).stem
|
basename = Path(filepath).stem
|
||||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||||
|
|
||||||
log.debug(f'Loading v4 embedding file: {short_path}')
|
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
if list(embedding_ckpt.keys()) == 0:
|
if list(embedding_ckpt.keys()) == 0:
|
||||||
log.warning(f"Invalid embeddings file: {short_path}")
|
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||||
else:
|
else:
|
||||||
for token,embedding in embedding_ckpt.items():
|
for token,embedding in embedding_ckpt.items():
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
|
@ -26,7 +26,7 @@ Console messages:
|
|||||||
|
|
||||||
Another way:
|
Another way:
|
||||||
import invokeai.backend.util.logging as ialog
|
import invokeai.backend.util.logging as ialog
|
||||||
ialog.debug('this is a debugging message')
|
ialogger.debug('this is a debugging message')
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ import torch
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from .devices import torch_dtype
|
from .devices import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
|
|||||||
try:
|
try:
|
||||||
draw.text((0, 0), lines, fill="black", font=font)
|
draw.text((0, 0), lines, fill="black", font=font)
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
log.warning("Cant encode string for logging. Skipping.")
|
logger.warning("Cant encode string for logging. Skipping.")
|
||||||
|
|
||||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||||
txts.append(txt)
|
txts.append(txt)
|
||||||
@ -81,7 +81,7 @@ def mean_flat(tensor):
|
|||||||
def count_params(model, verbose=False):
|
def count_params(model, verbose=False):
|
||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
log.debug(
|
logger.debug(
|
||||||
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||||
)
|
)
|
||||||
return total_params
|
return total_params
|
||||||
@ -133,7 +133,7 @@ def parallel_data_prefetch(
|
|||||||
raise ValueError("list expected but function got ndarray.")
|
raise ValueError("list expected but function got ndarray.")
|
||||||
elif isinstance(data, abc.Iterable):
|
elif isinstance(data, abc.Iterable):
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
log.warning(
|
logger.warning(
|
||||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||||
)
|
)
|
||||||
data = list(data.values())
|
data = list(data.values())
|
||||||
@ -176,7 +176,7 @@ def parallel_data_prefetch(
|
|||||||
processes += [p]
|
processes += [p]
|
||||||
|
|
||||||
# start processes
|
# start processes
|
||||||
log.info("Start prefetching...")
|
logger.info("Start prefetching...")
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -195,7 +195,7 @@ def parallel_data_prefetch(
|
|||||||
gather_res[res[0]] = res[1]
|
gather_res[res[0]] = res[1]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Exception: ", e)
|
logger.error("Exception: ", e)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ def parallel_data_prefetch(
|
|||||||
finally:
|
finally:
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
log.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||||
|
|
||||||
if target_data_type == "ndarray":
|
if target_data_type == "ndarray":
|
||||||
if not isinstance(gather_res[0], np.ndarray):
|
if not isinstance(gather_res[0], np.ndarray):
|
||||||
@ -319,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||||
|
|
||||||
if exist_size > content_length:
|
if exist_size > content_length:
|
||||||
log.warning("corrupt existing file found. re-downloading")
|
logger.warning("corrupt existing file found. re-downloading")
|
||||||
os.remove(dest)
|
os.remove(dest)
|
||||||
exist_size = 0
|
exist_size = 0
|
||||||
|
|
||||||
if resp.status_code == 416 or exist_size == content_length:
|
if resp.status_code == 416 or exist_size == content_length:
|
||||||
log.warning(f"{dest}: complete file found. Skipping.")
|
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||||
return dest
|
return dest
|
||||||
elif resp.status_code == 206 or exist_size > 0:
|
elif resp.status_code == 206 or exist_size > 0:
|
||||||
log.warning(f"{dest}: partial file found. Resuming...")
|
logger.warning(f"{dest}: partial file found. Resuming...")
|
||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
log.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||||
else:
|
else:
|
||||||
log.error(f"{dest}: Downloading...")
|
logger.error(f"{dest}: Downloading...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if content_length < 2000:
|
if content_length < 2000:
|
||||||
log.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(dest, open_mode) as file, tqdm(
|
with open(dest, open_mode) as file, tqdm(
|
||||||
@ -350,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
bar.update(size)
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"An error occurred while downloading {dest}: {str(e)}")
|
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return dest
|
return dest
|
||||||
|
@ -19,7 +19,7 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
import invokeai.frontend.web.dist as frontend
|
import invokeai.frontend.web.dist as frontend
|
||||||
|
|
||||||
from .. import Generate
|
from .. import Generate
|
||||||
@ -214,7 +214,7 @@ class InvokeAIWebServer:
|
|||||||
self.load_socketio_listeners(self.socketio)
|
self.load_socketio_listeners(self.socketio)
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
log.info("Launching Invoke AI GUI")
|
logger.info("Launching Invoke AI GUI")
|
||||||
try:
|
try:
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
|
||||||
@ -232,16 +232,16 @@ class InvokeAIWebServer:
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
useSSL = args.certfile or args.keyfile
|
useSSL = args.certfile or args.keyfile
|
||||||
log.info("Started Invoke AI Web Server")
|
logger.info("Started Invoke AI Web Server")
|
||||||
if self.host == "0.0.0.0":
|
if self.host == "0.0.0.0":
|
||||||
log.info(
|
logger.info(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info(
|
logger.info(
|
||||||
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||||
)
|
)
|
||||||
log.info(
|
logger.info(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||||
)
|
)
|
||||||
if not useSSL:
|
if not useSSL:
|
||||||
@ -274,7 +274,7 @@ class InvokeAIWebServer:
|
|||||||
# path for thumbnail images
|
# path for thumbnail images
|
||||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||||
# txt log
|
# txt log
|
||||||
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
|
||||||
# make all output paths
|
# make all output paths
|
||||||
[
|
[
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
@ -291,7 +291,7 @@ class InvokeAIWebServer:
|
|||||||
def load_socketio_listeners(self, socketio):
|
def load_socketio_listeners(self, socketio):
|
||||||
@socketio.on("requestSystemConfig")
|
@socketio.on("requestSystemConfig")
|
||||||
def handle_request_capabilities():
|
def handle_request_capabilities():
|
||||||
log.info("System config requested")
|
logger.info("System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_manager.list_models()
|
config["model_list"] = self.generate.model_manager.list_models()
|
||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
@ -331,7 +331,7 @@ class InvokeAIWebServer:
|
|||||||
if model_name in current_model_list:
|
if model_name in current_model_list:
|
||||||
update = True
|
update = True
|
||||||
|
|
||||||
log.info(f"Adding New Model: {model_name}")
|
logger.info(f"Adding New Model: {model_name}")
|
||||||
|
|
||||||
self.generate.model_manager.add_model(
|
self.generate.model_manager.add_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -349,14 +349,14 @@ class InvokeAIWebServer:
|
|||||||
"update": update,
|
"update": update,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log.info(f"New Model Added: {model_name}")
|
logger.info(f"New Model Added: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("deleteModel")
|
@socketio.on("deleteModel")
|
||||||
def handle_delete_model(model_name: str):
|
def handle_delete_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
log.info(f"Deleting Model: {model_name}")
|
logger.info(f"Deleting Model: {model_name}")
|
||||||
self.generate.model_manager.del_model(model_name)
|
self.generate.model_manager.del_model(model_name)
|
||||||
self.generate.model_manager.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
updated_model_list = self.generate.model_manager.list_models()
|
updated_model_list = self.generate.model_manager.list_models()
|
||||||
@ -367,14 +367,14 @@ class InvokeAIWebServer:
|
|||||||
"model_list": updated_model_list,
|
"model_list": updated_model_list,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log.info(f"Model Deleted: {model_name}")
|
logger.info(f"Model Deleted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
def handle_set_model(model_name: str):
|
def handle_set_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
log.info(f"Model change requested: {model_name}")
|
logger.info(f"Model change requested: {model_name}")
|
||||||
model = self.generate.set_model(model_name)
|
model = self.generate.set_model(model_name)
|
||||||
model_list = self.generate.model_manager.list_models()
|
model_list = self.generate.model_manager.list_models()
|
||||||
if model is None:
|
if model is None:
|
||||||
@ -455,7 +455,7 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log.info(f"Model Converted: {model_name}")
|
logger.info(f"Model Converted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -491,7 +491,7 @@ class InvokeAIWebServer:
|
|||||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||||
"vae", None
|
"vae", None
|
||||||
):
|
):
|
||||||
log.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
||||||
merged_model_config.update(vae=vae)
|
merged_model_config.update(vae=vae)
|
||||||
|
|
||||||
self.generate.model_manager.import_diffuser_model(
|
self.generate.model_manager.import_diffuser_model(
|
||||||
@ -508,8 +508,8 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log.info(f"Models Merged: {models_to_merge}")
|
logger.info(f"Models Merged: {models_to_merge}")
|
||||||
log.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -699,7 +699,7 @@ class InvokeAIWebServer:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.info(f"Unable to load {path}")
|
logger.info(f"Unable to load {path}")
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||||
)
|
)
|
||||||
@ -736,9 +736,9 @@ class InvokeAIWebServer:
|
|||||||
printable_parameters["init_mask"][:64] + "..."
|
printable_parameters["init_mask"][:64] + "..."
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||||
log.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
||||||
log.info(f"Facetool Parameters: {facetool_parameters}")
|
logger.info(f"Facetool Parameters: {facetool_parameters}")
|
||||||
|
|
||||||
self.generate_images(
|
self.generate_images(
|
||||||
generation_parameters,
|
generation_parameters,
|
||||||
@ -751,7 +751,7 @@ class InvokeAIWebServer:
|
|||||||
@socketio.on("runPostprocessing")
|
@socketio.on("runPostprocessing")
|
||||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||||
try:
|
try:
|
||||||
log.info(
|
logger.info(
|
||||||
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -862,14 +862,14 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
@socketio.on("cancel")
|
@socketio.on("cancel")
|
||||||
def handle_cancel():
|
def handle_cancel():
|
||||||
log.info("Cancel processing requested")
|
logger.info("Cancel processing requested")
|
||||||
self.canceled.set()
|
self.canceled.set()
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on("deleteImage")
|
@socketio.on("deleteImage")
|
||||||
def handle_delete_image(url, thumbnail, uuid, category):
|
def handle_delete_image(url, thumbnail, uuid, category):
|
||||||
try:
|
try:
|
||||||
log.info(f'Delete requested "{url}"')
|
logger.info(f'Delete requested "{url}"')
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
path = self.get_image_path_from_url(url)
|
path = self.get_image_path_from_url(url)
|
||||||
@ -1264,7 +1264,7 @@ class InvokeAIWebServer:
|
|||||||
image, os.path.basename(path), self.thumbnail_image_path
|
image, os.path.basename(path), self.thumbnail_image_path
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f'Image generated: "{path}"\n')
|
logger.info(f'Image generated: "{path}"\n')
|
||||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
if progress.total_iterations > progress.current_iteration:
|
if progress.total_iterations > progress.current_iteration:
|
||||||
@ -1330,7 +1330,7 @@ class InvokeAIWebServer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clear the CUDA cache on an exception
|
# Clear the CUDA cache on an exception
|
||||||
self.empty_cuda_cache()
|
self.empty_cuda_cache()
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
def empty_cuda_cache(self):
|
def empty_cuda_cache(self):
|
||||||
|
@ -16,7 +16,7 @@ if sys.platform == "darwin":
|
|||||||
import pyparsing # type: ignore
|
import pyparsing # type: ignore
|
||||||
|
|
||||||
import invokeai.version as invokeai
|
import invokeai.version as invokeai
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from ...backend import Generate, ModelManager
|
from ...backend import Generate, ModelManager
|
||||||
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||||
@ -70,7 +70,7 @@ def main():
|
|||||||
# run any post-install patches needed
|
# run any post-install patches needed
|
||||||
run_patches()
|
run_patches()
|
||||||
|
|
||||||
log.info(f"Internet connectivity is {Globals.internet_available}")
|
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
@ -79,8 +79,8 @@ def main():
|
|||||||
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
|
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
|
||||||
log.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
# loading here to avoid long delays on startup
|
# loading here to avoid long delays on startup
|
||||||
# these two lines prevent a horrible warning message from appearing
|
# these two lines prevent a horrible warning message from appearing
|
||||||
@ -122,7 +122,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"{opt.infile} not found.")
|
raise FileNotFoundError(f"{opt.infile} not found.")
|
||||||
except (FileNotFoundError, IOError) as e:
|
except (FileNotFoundError, IOError) as e:
|
||||||
log.critical('Aborted',exc_info=True)
|
logger.critical('Aborted',exc_info=True)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# creating a Generate object:
|
# creating a Generate object:
|
||||||
@ -144,11 +144,11 @@ def main():
|
|||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
except (IOError, KeyError):
|
except (IOError, KeyError):
|
||||||
log.critical("Aborted",exc_info=True)
|
logger.critical("Aborted",exc_info=True)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
if opt.seamless:
|
if opt.seamless:
|
||||||
log.info("Changed to seamless tiling mode")
|
logger.info("Changed to seamless tiling mode")
|
||||||
|
|
||||||
# preload the model
|
# preload the model
|
||||||
try:
|
try:
|
||||||
@ -181,7 +181,7 @@ def main():
|
|||||||
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.error("An error occurred",exc_info=True)
|
logger.error("An error occurred",exc_info=True)
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, opt):
|
def main_loop(gen, opt):
|
||||||
@ -247,7 +247,7 @@ def main_loop(gen, opt):
|
|||||||
if not opt.prompt:
|
if not opt.prompt:
|
||||||
oldargs = metadata_from_png(opt.init_img)
|
oldargs = metadata_from_png(opt.init_img)
|
||||||
opt.prompt = oldargs.prompt
|
opt.prompt = oldargs.prompt
|
||||||
log.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||||
except (OSError, AttributeError, KeyError):
|
except (OSError, AttributeError, KeyError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -264,9 +264,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
||||||
try:
|
try:
|
||||||
opt.init_img = last_results[int(opt.init_img)][0]
|
opt.init_img = last_results[int(opt.init_img)][0]
|
||||||
log.info(f"Reusing previous image {opt.init_img}")
|
logger.info(f"Reusing previous image {opt.init_img}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
log.info(f"No previous initial image at position {opt.init_img} found")
|
logger.info(f"No previous initial image at position {opt.init_img} found")
|
||||||
opt.init_img = None
|
opt.init_img = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -287,9 +287,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
||||||
try:
|
try:
|
||||||
opt.seed = last_results[opt.seed][1]
|
opt.seed = last_results[opt.seed][1]
|
||||||
log.info(f"Reusing previous seed {opt.seed}")
|
logger.info(f"Reusing previous seed {opt.seed}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
log.info(f"No previous seed at position {opt.seed} found")
|
logger.info(f"No previous seed at position {opt.seed} found")
|
||||||
opt.seed = None
|
opt.seed = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -308,7 +308,7 @@ def main_loop(gen, opt):
|
|||||||
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
||||||
current_outdir = os.path.join(opt.outdir, subdir)
|
current_outdir = os.path.join(opt.outdir, subdir)
|
||||||
|
|
||||||
log.info('Writing files to directory: "' + current_outdir + '"')
|
logger.info('Writing files to directory: "' + current_outdir + '"')
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
if not os.path.exists(current_outdir):
|
if not os.path.exists(current_outdir):
|
||||||
@ -438,13 +438,13 @@ def main_loop(gen, opt):
|
|||||||
**vars(opt),
|
**vars(opt),
|
||||||
)
|
)
|
||||||
except (PromptParser.ParsingException, pyparsing.ParseException):
|
except (PromptParser.ParsingException, pyparsing.ParseException):
|
||||||
log.error("An error occurred while processing your prompt",exc_info=True)
|
logger.error("An error occurred while processing your prompt",exc_info=True)
|
||||||
elif operation == "postprocess":
|
elif operation == "postprocess":
|
||||||
log.info(f"fixing {opt.prompt}")
|
logger.info(f"fixing {opt.prompt}")
|
||||||
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
||||||
|
|
||||||
elif operation == "mask":
|
elif operation == "mask":
|
||||||
log.info(f"generating masks from {opt.prompt}")
|
logger.info(f"generating masks from {opt.prompt}")
|
||||||
do_textmask(gen, opt, image_writer)
|
do_textmask(gen, opt, image_writer)
|
||||||
|
|
||||||
if opt.grid and len(grid_images) > 0:
|
if opt.grid and len(grid_images) > 0:
|
||||||
@ -468,11 +468,11 @@ def main_loop(gen, opt):
|
|||||||
results = [[path, formatted_dream_prompt]]
|
results = [[path, formatted_dream_prompt]]
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
@ -511,7 +511,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
add_embedding_terms(gen, completer)
|
add_embedding_terms(gen, completer)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -525,7 +525,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!import"):
|
elif command.startswith("!import"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
log.warning(
|
logger.warning(
|
||||||
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -539,7 +539,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith(("!convert", "!optimize")):
|
elif command.startswith(("!convert", "!optimize")):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
log.warning("please provide the path to a .ckpt or .safetensors model")
|
logger.warning("please provide the path to a .ckpt or .safetensors model")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_model(path[1], gen, opt, completer)
|
convert_model(path[1], gen, opt, completer)
|
||||||
@ -551,7 +551,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!edit"):
|
elif command.startswith("!edit"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
log.warning("please provide the name of a model")
|
logger.warning("please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
edit_model(path[1], gen, opt, completer)
|
edit_model(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -560,7 +560,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!del"):
|
elif command.startswith("!del"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
log.warning("please provide the name of a model")
|
logger.warning("please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
del_config(path[1], gen, opt, completer)
|
del_config(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -641,7 +641,7 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
default_name = url_attachment_name(model_path)
|
default_name = url_attachment_name(model_path)
|
||||||
default_name = Path(default_name).stem
|
default_name = Path(default_name).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
log.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
||||||
model_name, model_desc = _get_model_name_and_desc(
|
model_name, model_desc = _get_model_name_and_desc(
|
||||||
gen.model_manager,
|
gen.model_manager,
|
||||||
completer,
|
completer,
|
||||||
@ -662,11 +662,11 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
model_config_file=config_file,
|
model_config_file=config_file,
|
||||||
)
|
)
|
||||||
if not imported_name:
|
if not imported_name:
|
||||||
log.error("Aborting import.")
|
logger.error("Aborting import.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not _verify_load(imported_name, gen):
|
if not _verify_load(imported_name, gen):
|
||||||
log.error("model failed to load. Discarding configuration entry")
|
logger.error("model failed to load. Discarding configuration entry")
|
||||||
gen.model_manager.del_model(imported_name)
|
gen.model_manager.del_model(imported_name)
|
||||||
return
|
return
|
||||||
if click.confirm("Make this the default model?", default=False):
|
if click.confirm("Make this the default model?", default=False):
|
||||||
@ -674,7 +674,7 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
|
|
||||||
gen.model_manager.commit(opt.conf)
|
gen.model_manager.commit(opt.conf)
|
||||||
completer.update_models(gen.model_manager.list_models())
|
completer.update_models(gen.model_manager.list_models())
|
||||||
log.info(f"{imported_name} successfully installed")
|
logger.info(f"{imported_name} successfully installed")
|
||||||
|
|
||||||
def _pick_configuration_file(completer)->Path:
|
def _pick_configuration_file(completer)->Path:
|
||||||
print(
|
print(
|
||||||
@ -718,21 +718,21 @@ Please select the type of this model:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
def _verify_load(model_name: str, gen) -> bool:
|
def _verify_load(model_name: str, gen) -> bool:
|
||||||
log.info("Verifying that new model loads...")
|
logger.info("Verifying that new model loads...")
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
try:
|
try:
|
||||||
if not gen.set_model(model_name):
|
if not gen.set_model(model_name):
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"model failed to load: {str(e)}")
|
logger.warning(f"model failed to load: {str(e)}")
|
||||||
log.warning(
|
logger.warning(
|
||||||
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
if click.confirm("Keep model loaded?", default=True):
|
if click.confirm("Keep model loaded?", default=True):
|
||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
else:
|
else:
|
||||||
log.info("Restoring previous model")
|
logger.info("Restoring previous model")
|
||||||
gen.set_model(current_model)
|
gen.set_model(current_model)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -755,7 +755,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
ckpt_path = None
|
ckpt_path = None
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
if model_name_or_path == gen.model_name:
|
if model_name_or_path == gen.model_name:
|
||||||
log.warning("Can't convert the active model. !switch to another model first. **")
|
logger.warning("Can't convert the active model. !switch to another model first. **")
|
||||||
return
|
return
|
||||||
elif model_info := manager.model_info(model_name_or_path):
|
elif model_info := manager.model_info(model_name_or_path):
|
||||||
if "weights" in model_info:
|
if "weights" in model_info:
|
||||||
@ -765,7 +765,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
model_description = model_info["description"]
|
model_description = model_info["description"]
|
||||||
vae_path = model_info.get("vae")
|
vae_path = model_info.get("vae")
|
||||||
else:
|
else:
|
||||||
log.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
|
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
|
||||||
return
|
return
|
||||||
model_name = manager.convert_and_import(
|
model_name = manager.convert_and_import(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
@ -786,16 +786,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
manager.commit(opt.conf)
|
manager.commit(opt.conf)
|
||||||
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
||||||
ckpt_path.unlink(missing_ok=True)
|
ckpt_path.unlink(missing_ok=True)
|
||||||
log.warning(f"{ckpt_path} deleted")
|
logger.warning(f"{ckpt_path} deleted")
|
||||||
|
|
||||||
|
|
||||||
def del_config(model_name: str, gen, opt, completer):
|
def del_config(model_name: str, gen, opt, completer):
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
if model_name == current_model:
|
if model_name == current_model:
|
||||||
log.warning("Can't delete active model. !switch to another model first. **")
|
logger.warning("Can't delete active model. !switch to another model first. **")
|
||||||
return
|
return
|
||||||
if model_name not in gen.model_manager.config:
|
if model_name not in gen.model_manager.config:
|
||||||
log.warning(f"Unknown model {model_name}")
|
logger.warning(f"Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not click.confirm(
|
if not click.confirm(
|
||||||
@ -808,17 +808,17 @@ def del_config(model_name: str, gen, opt, completer):
|
|||||||
)
|
)
|
||||||
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
||||||
gen.model_manager.commit(opt.conf)
|
gen.model_manager.commit(opt.conf)
|
||||||
log.warning(f"{model_name} deleted")
|
logger.warning(f"{model_name} deleted")
|
||||||
completer.update_models(gen.model_manager.list_models())
|
completer.update_models(gen.model_manager.list_models())
|
||||||
|
|
||||||
|
|
||||||
def edit_model(model_name: str, gen, opt, completer):
|
def edit_model(model_name: str, gen, opt, completer):
|
||||||
manager = gen.model_manager
|
manager = gen.model_manager
|
||||||
if not (info := manager.model_info(model_name)):
|
if not (info := manager.model_info(model_name)):
|
||||||
log.warning(f"** Unknown model {model_name}")
|
logger.warning(f"** Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
print()
|
print()
|
||||||
log.info(f"Editing model {model_name} from configuration file {opt.conf}")
|
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
|
||||||
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
||||||
|
|
||||||
for attribute in info.keys():
|
for attribute in info.keys():
|
||||||
@ -856,7 +856,7 @@ def edit_model(model_name: str, gen, opt, completer):
|
|||||||
manager.set_default_model(new_name)
|
manager.set_default_model(new_name)
|
||||||
manager.commit(opt.conf)
|
manager.commit(opt.conf)
|
||||||
completer.update_models(manager.list_models())
|
completer.update_models(manager.list_models())
|
||||||
log.info("Model successfully updated")
|
logger.info("Model successfully updated")
|
||||||
|
|
||||||
|
|
||||||
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
||||||
@ -867,11 +867,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
|||||||
if len(model_name) == 0:
|
if len(model_name) == 0:
|
||||||
model_name = default_name
|
model_name = default_name
|
||||||
if not re.match("^[\w._+:/-]+$", model_name):
|
if not re.match("^[\w._+:/-]+$", model_name):
|
||||||
log.warning(
|
logger.warning(
|
||||||
'model name must contain only words, digits and the characters "._+:/-" **'
|
'model name must contain only words, digits and the characters "._+:/-" **'
|
||||||
)
|
)
|
||||||
elif model_name != default_name and model_name in existing_names:
|
elif model_name != default_name and model_name in existing_names:
|
||||||
log.warning(f"the name {model_name} is already in use. Pick another.")
|
logger.warning(f"the name {model_name} is already in use. Pick another.")
|
||||||
else:
|
else:
|
||||||
done = True
|
done = True
|
||||||
return model_name
|
return model_name
|
||||||
@ -938,10 +938,10 @@ def do_postprocess(gen, opt, callback):
|
|||||||
opt=opt,
|
opt=opt,
|
||||||
)
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
log.error(f"{file_path}: file could not be read",exc_info=True)
|
logger.error(f"{file_path}: file could not be read",exc_info=True)
|
||||||
return
|
return
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
log.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
|
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
|
||||||
return
|
return
|
||||||
return opt.last_operation
|
return opt.last_operation
|
||||||
|
|
||||||
@ -996,12 +996,12 @@ def prepare_image_metadata(
|
|||||||
try:
|
try:
|
||||||
filename = opt.fnformat.format(**wildcards)
|
filename = opt.fnformat.format(**wildcards)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
log.error(
|
logger.error(
|
||||||
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
||||||
)
|
)
|
||||||
filename = f"{prefix}.{seed}.png"
|
filename = f"{prefix}.{seed}.png"
|
||||||
except IndexError:
|
except IndexError:
|
||||||
log.error(
|
logger.error(
|
||||||
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
||||||
)
|
)
|
||||||
filename = f"{prefix}.{seed}.png"
|
filename = f"{prefix}.{seed}.png"
|
||||||
@ -1091,14 +1091,14 @@ def split_variations(variations_string) -> list:
|
|||||||
for part in variations_string.split(","):
|
for part in variations_string.split(","):
|
||||||
seed_and_weight = part.split(":")
|
seed_and_weight = part.split(":")
|
||||||
if len(seed_and_weight) != 2:
|
if len(seed_and_weight) != 2:
|
||||||
log.warning(f'Could not parse with_variation part "{part}"')
|
logger.warning(f'Could not parse with_variation part "{part}"')
|
||||||
broken = True
|
broken = True
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
seed = int(seed_and_weight[0])
|
seed = int(seed_and_weight[0])
|
||||||
weight = float(seed_and_weight[1])
|
weight = float(seed_and_weight[1])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
log.warning(f'Could not parse with_variation part "{part}"')
|
logger.warning(f'Could not parse with_variation part "{part}"')
|
||||||
broken = True
|
broken = True
|
||||||
break
|
break
|
||||||
parts.append([seed, weight])
|
parts.append([seed, weight])
|
||||||
@ -1122,23 +1122,23 @@ def load_face_restoration(opt):
|
|||||||
opt.gfpgan_model_path
|
opt.gfpgan_model_path
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.info("Face restoration disabled")
|
logger.info("Face restoration disabled")
|
||||||
if opt.esrgan:
|
if opt.esrgan:
|
||||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||||
else:
|
else:
|
||||||
log.info("Upscaling disabled")
|
logger.info("Upscaling disabled")
|
||||||
else:
|
else:
|
||||||
log.info("Face restoration and upscaling disabled")
|
logger.info("Face restoration and upscaling disabled")
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
log.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
return gfpgan, codeformer, esrgan
|
return gfpgan, codeformer, esrgan
|
||||||
|
|
||||||
|
|
||||||
def make_step_callback(gen, opt, prefix):
|
def make_step_callback(gen, opt, prefix):
|
||||||
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
||||||
os.makedirs(destination, exist_ok=True)
|
os.makedirs(destination, exist_ok=True)
|
||||||
log.info(f"Intermediate images will be written into {destination}")
|
logger.info(f"Intermediate images will be written into {destination}")
|
||||||
|
|
||||||
def callback(state: PipelineIntermediateState):
|
def callback(state: PipelineIntermediateState):
|
||||||
latents = state.latents
|
latents = state.latents
|
||||||
@ -1180,11 +1180,11 @@ def retrieve_dream_command(opt, command, completer):
|
|||||||
try:
|
try:
|
||||||
cmd = dream_cmd_from_png(path)
|
cmd = dream_cmd_from_png(path)
|
||||||
except OSError:
|
except OSError:
|
||||||
log.error(f"{tokens[0]}: file could not be read")
|
logger.error(f"{tokens[0]}: file could not be read")
|
||||||
except (KeyError, AttributeError, IndexError):
|
except (KeyError, AttributeError, IndexError):
|
||||||
log.error(f"{tokens[0]}: file has no metadata")
|
logger.error(f"{tokens[0]}: file has no metadata")
|
||||||
except:
|
except:
|
||||||
log.error(f"{tokens[0]}: file could not be processed")
|
logger.error(f"{tokens[0]}: file could not be processed")
|
||||||
if len(cmd) > 0:
|
if len(cmd) > 0:
|
||||||
completer.set_line(cmd)
|
completer.set_line(cmd)
|
||||||
|
|
||||||
@ -1193,7 +1193,7 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
|||||||
try:
|
try:
|
||||||
paths = sorted(list(Path(dir).glob(basename)))
|
paths = sorted(list(Path(dir).glob(basename)))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
log.error(f'"{basename}": unacceptable pattern')
|
logger.error(f'"{basename}": unacceptable pattern')
|
||||||
return
|
return
|
||||||
|
|
||||||
commands = []
|
commands = []
|
||||||
@ -1202,9 +1202,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
|||||||
try:
|
try:
|
||||||
cmd = dream_cmd_from_png(path)
|
cmd = dream_cmd_from_png(path)
|
||||||
except (KeyError, AttributeError, IndexError):
|
except (KeyError, AttributeError, IndexError):
|
||||||
log.error(f"{path}: file has no metadata")
|
logger.error(f"{path}: file has no metadata")
|
||||||
except:
|
except:
|
||||||
log.error(f"{path}: file could not be processed")
|
logger.error(f"{path}: file could not be processed")
|
||||||
if cmd:
|
if cmd:
|
||||||
commands.append(f"# {path}")
|
commands.append(f"# {path}")
|
||||||
commands.append(cmd)
|
commands.append(cmd)
|
||||||
@ -1214,17 +1214,17 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
|||||||
outfilepath = os.path.join(opt.outdir, basename)
|
outfilepath = os.path.join(opt.outdir, basename)
|
||||||
with open(outfilepath, "w", encoding="utf-8") as f:
|
with open(outfilepath, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(commands))
|
f.write("\n".join(commands))
|
||||||
log.info(f"File {outfilepath} with commands created")
|
logger.info(f"File {outfilepath} with commands created")
|
||||||
|
|
||||||
|
|
||||||
def report_model_error(opt: Namespace, e: Exception):
|
def report_model_error(opt: Namespace, e: Exception):
|
||||||
log.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||||
log.warning(
|
logger.warning(
|
||||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||||
)
|
)
|
||||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||||
if yes_to_all:
|
if yes_to_all:
|
||||||
log.warning(
|
logger.warning(
|
||||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -1234,7 +1234,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info("invokeai-configure is launching....\n")
|
logger.info("invokeai-configure is launching....\n")
|
||||||
|
|
||||||
# Match arguments that were set on the CLI
|
# Match arguments that were set on the CLI
|
||||||
# only the arguments accepted by the configuration script are parsed
|
# only the arguments accepted by the configuration script are parsed
|
||||||
@ -1251,7 +1251,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
from ..install import invokeai_configure
|
from ..install import invokeai_configure
|
||||||
|
|
||||||
invokeai_configure()
|
invokeai_configure()
|
||||||
log.warning("InvokeAI will now restart")
|
logger.warning("InvokeAI will now restart")
|
||||||
sys.argv = previous_args
|
sys.argv = previous_args
|
||||||
main() # would rather do a os.exec(), but doesn't exist?
|
main() # would rather do a os.exec(), but doesn't exist?
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -22,7 +22,7 @@ import torch
|
|||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals, global_config_dir
|
from invokeai.backend.globals import Globals, global_config_dir
|
||||||
|
|
||||||
from ...backend.config.model_install_backend import (
|
from ...backend.config.model_install_backend import (
|
||||||
@ -456,7 +456,7 @@ def main():
|
|||||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||||
|
|
||||||
if not global_config_dir().exists():
|
if not global_config_dir().exists():
|
||||||
log.info(
|
logger.info(
|
||||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||||
)
|
)
|
||||||
from invokeai.frontend.install import invokeai_configure
|
from invokeai.frontend.install import invokeai_configure
|
||||||
@ -467,17 +467,17 @@ def main():
|
|||||||
try:
|
try:
|
||||||
select_and_download_models(opt)
|
select_and_download_models(opt)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.info("Goodbye! Come back soon.")
|
logger.info("Goodbye! Come back soon.")
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
log.error(
|
logger.error(
|
||||||
"Insufficient vertical space for the interface. Please make your window taller and try again"
|
"Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||||
)
|
)
|
||||||
elif str(e).startswith("addwstr"):
|
elif str(e).startswith("addwstr"):
|
||||||
log.error(
|
logger.error(
|
||||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from ...backend.globals import (
|
|||||||
global_set_root,
|
global_set_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from ...backend.model_management import ModelManager
|
from ...backend.model_management import ModelManager
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from ...frontend.install.widgets import FloatTitleSlider
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ def merge_diffusion_models_and_commit(
|
|||||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||||
)
|
)
|
||||||
if vae := model_manager.config[models[0]].get("vae", None):
|
if vae := model_manager.config[models[0]].get("vae", None):
|
||||||
log.info(f"Using configured VAE assigned to {models[0]}")
|
logger.info(f"Using configured VAE assigned to {models[0]}")
|
||||||
import_args.update(vae=vae)
|
import_args.update(vae=vae)
|
||||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||||
model_manager.commit(config_file)
|
model_manager.commit(config_file)
|
||||||
@ -414,7 +414,7 @@ def run_gui(args: Namespace):
|
|||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merge_diffusion_models_and_commit(**args)
|
merge_diffusion_models_and_commit(**args)
|
||||||
log.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
@ -425,7 +425,7 @@ def run_cli(args: Namespace):
|
|||||||
|
|
||||||
if not args.merged_model_name:
|
if not args.merged_model_name:
|
||||||
args.merged_model_name = "+".join(args.models)
|
args.merged_model_name = "+".join(args.models)
|
||||||
log.info(
|
logger.info(
|
||||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -435,7 +435,7 @@ def run_cli(args: Namespace):
|
|||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merge_diffusion_models_and_commit(**vars(args))
|
merge_diffusion_models_and_commit(**vars(args))
|
||||||
log.info(f'Models merged into new model: "{args.merged_model_name}".')
|
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -455,16 +455,16 @@ def main():
|
|||||||
run_cli(args)
|
run_cli(args)
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
log.error(
|
logger.error(
|
||||||
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.error(
|
logger.error(
|
||||||
"Not enough room for the user interface. Try making this window larger."
|
"Not enough room for the user interface. Try making this window larger."
|
||||||
)
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
@ -20,7 +20,7 @@ import npyscreen
|
|||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as log
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals, global_set_root
|
from invokeai.backend.globals import Globals, global_set_root
|
||||||
|
|
||||||
from ...backend.training import do_textual_inversion_training, parse_args
|
from ...backend.training import do_textual_inversion_training, parse_args
|
||||||
@ -369,14 +369,14 @@ def copy_to_embeddings_folder(args: dict):
|
|||||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
||||||
os.makedirs(destination, exist_ok=True)
|
os.makedirs(destination, exist_ok=True)
|
||||||
log.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||||
shutil.copy(source, destination)
|
shutil.copy(source, destination)
|
||||||
if (
|
if (
|
||||||
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
||||||
).startswith(("y", "Y")):
|
).startswith(("y", "Y")):
|
||||||
shutil.rmtree(Path(args["output_dir"]))
|
shutil.rmtree(Path(args["output_dir"]))
|
||||||
else:
|
else:
|
||||||
log.info(f'Keeping {args["output_dir"]}')
|
logger.info(f'Keeping {args["output_dir"]}')
|
||||||
|
|
||||||
|
|
||||||
def save_args(args: dict):
|
def save_args(args: dict):
|
||||||
@ -423,10 +423,10 @@ def do_front_end(args: Namespace):
|
|||||||
do_textual_inversion_training(**args)
|
do_textual_inversion_training(**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("An exception occurred during training. The exception was:")
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
log.error(str(e))
|
logger.error(str(e))
|
||||||
log.error("DETAILS:")
|
logger.error("DETAILS:")
|
||||||
log.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -438,21 +438,21 @@ def main():
|
|||||||
else:
|
else:
|
||||||
do_textual_inversion_training(**vars(args))
|
do_textual_inversion_training(**vars(args))
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
log.error(
|
logger.error(
|
||||||
"You need to have at least one diffusers models defined in models.yaml in order to train"
|
"You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||||
)
|
)
|
||||||
elif str(e).startswith("addwstr"):
|
elif str(e).startswith("addwstr"):
|
||||||
log.error(
|
logger.error(
|
||||||
"Not enough window space for the interface. Please make your window larger and try again."
|
"Not enough window space for the interface. Please make your window larger and try again."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.error(e)
|
logger.error(e)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user