logger is an interchangeable service

Lincoln Stein
2023-04-29 10:48:50 -04:00
parent 8db20e0d95
commit 974841926d
8 changed files with 108 additions and 100 deletions
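
In short: the module-level `import invokeai.backend.util.logging as logger` statements are dropped, and the logger is instead passed in at construction time and stored on the service objects, so the caller decides which logging backend the services use. A minimal, hedged sketch of the pattern (illustrative names, not the actual InvokeAI classes):

import logging
from types import ModuleType
from typing import Union

class SomeService:
    """Illustrative service: the logger is injected rather than imported at module level."""

    def __init__(self, logger: Union[ModuleType, logging.Logger]):
        # Anything exposing info/warning/error/debug can be injected here.
        self.logger = logger

    def do_work(self) -> None:
        self.logger.info("work started")

# The caller chooses the logging backend once, at wiring time.
SomeService(logging.getLogger("invokeai")).do_work()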

View File

@@ -67,8 +67,9 @@ class ApiDependencies:
 db_location = os.path.join(output_folder, "invokeai.db")
 services = InvocationServices(
-model_manager=get_model_manager(config),
+model_manager=get_model_manager(config,logger),
 events=events,
+logger=logger,
 latents=latents,
 images=images,
 metadata=metadata,
@@ -80,7 +81,7 @@ class ApiDependencies:
 filename=db_location, table_name="graph_executions"
 ),
 processor=DefaultInvocationProcessor(),
-restoration=RestorationServices(config),
+restoration=RestorationServices(config,logger),
 )
 create_system_graphs(services.graph_library)

View File

@@ -4,15 +4,10 @@ import shutil
 import asyncio
 from typing import Annotated, Any, List, Literal, Optional, Union
-import invokeai.backend.util.logging as logger
 from fastapi.routing import APIRouter, HTTPException
 from pydantic import BaseModel, Field, parse_obj_as
 from pathlib import Path
 from ..dependencies import ApiDependencies
-from invokeai.backend.globals import Globals, global_converted_ckpts_dir
-from invokeai.backend.args import Args
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -113,6 +108,7 @@ async def update_model(
 async def delete_model(model_name: str) -> None:
 """Delete Model"""
 model_names = ApiDependencies.invoker.services.model_manager.model_names()
+logger = ApiDependencies.invoker.services.logger
 model_exists = model_name in model_names
 # check if model exists
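
For reference, a hedged sketch of the per-request pattern the delete_model route now follows: the logger is looked up on the shared services object instead of being imported at module scope (the handler name and log line below are illustrative, not part of the commit):

from ..dependencies import ApiDependencies  # same relative import used by this router

async def delete_model_sketch(model_name: str) -> None:
    # Resolve the injected logger at request time, mirroring the hunk above.
    logger = ApiDependencies.invoker.services.logger
    logger.info(f"Request received to delete model {model_name}")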

View File

@@ -181,7 +181,7 @@ def invoke_all(context: CliContext):
 # Print any errors
 if context.session.has_error():
 for n in context.session.errors:
-logger.error(
+context.invoker.services.logger.error(
 f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
 )
@@ -191,7 +191,7 @@ def invoke_all(context: CliContext):
 def invoke_cli():
 config = Args()
 config.parse_args()
-model_manager = get_model_manager(config)
+model_manager = get_model_manager(config,logger=logger)
 # This initializes the autocompleter and returns it.
 # Currently nothing is done with the returned Completer
@@ -224,7 +224,8 @@ def invoke_cli():
 filename=db_location, table_name="graph_executions"
 ),
 processor=DefaultInvocationProcessor(),
-restoration=RestorationServices(config),
+restoration=RestorationServices(config,logger=logger),
+logger=logger,
 )
 system_graphs = create_system_graphs(services.graph_library)
@@ -364,12 +365,12 @@ def invoke_cli():
 invoke_all(context)
 except InvalidArgs:
-logger.warning('Invalid command, use "help" to list commands')
+invoker.services.logger.warning('Invalid command, use "help" to list commands')
 continue
 except SessionError:
 # Start a new session
-logger.warning("Session error: creating a new session")
+invoker.services.logger.warning("Session error: creating a new session")
 context.reset()
 except ExitCli:

View File

@@ -1,9 +1,9 @@
-import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management.model_manager import ModelManager
 def choose_model(model_manager: ModelManager, model_name: str):
 """Returns the default model if the `model_name` not a valid model, else returns the selected model."""
+logger = model_manager.logger
 if model_manager.valid_model(model_name):
 model = model_manager.get_model(model_name)
 else:

View File

@@ -1,4 +1,6 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
+from typing import types
 from invokeai.app.services.metadata import MetadataServiceBase
 from invokeai.backend import ModelManager
@@ -29,6 +31,7 @@ class InvocationServices:
 self,
 model_manager: ModelManager,
 events: EventServiceBase,
+logger: types.ModuleType,
 latents: LatentsStorageBase,
 images: ImageStorageBase,
 metadata: MetadataServiceBase,
@@ -40,6 +43,7 @@ class InvocationServices:
 ):
 self.model_manager = model_manager
 self.events = events
+self.logger = logger
 self.latents = latents
 self.images = images
 self.metadata = metadata
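
Because InvocationServices simply stores whatever logging object it is handed on self.logger, any consumer that already holds the services container can log through it. A self-contained, hedged sketch with stand-in types (not the real InvokeAI classes) showing the consumer side:

import logging
from dataclasses import dataclass

@dataclass
class FakeServices:          # stands in for InvocationServices
    logger: logging.Logger

@dataclass
class FakeContext:           # stands in for an invocation's context object
    services: FakeServices

def run_example_invocation(context: FakeContext) -> None:
    # No module-level logging import is needed in node code.
    context.services.logger.info("example invocation running")

run_example_invocation(FakeContext(FakeServices(logging.getLogger("invokeai"))))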

View File

@@ -5,20 +5,20 @@ from argparse import Namespace
 from invokeai.backend import Args
 from omegaconf import OmegaConf
 from pathlib import Path
+from typing import types
 import invokeai.version
-import invokeai.backend.util.logging as logger
 from ...backend import ModelManager
 from ...backend.util import choose_precision, choose_torch_device
 from ...backend import Globals
 # TODO: Replace with an abstract class base ModelManagerBase
-def get_model_manager(config: Args) -> ModelManager:
+def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
 if not config.conf:
 config_file = os.path.join(Globals.root, "configs", "models.yaml")
 if not os.path.exists(config_file):
 report_model_error(
-config, FileNotFoundError(f"The file {config_file} could not be found.")
+config, FileNotFoundError(f"The file {config_file} could not be found."), logger
 )
 logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
@@ -63,9 +63,10 @@ def get_model_manager(config: Args) -> ModelManager:
 device_type=device,
 max_loaded_models=config.max_loaded_models,
 embedding_path = Path(embedding_path),
+logger = logger,
 )
 except (FileNotFoundError, TypeError, AssertionError) as e:
-report_model_error(config, e)
+report_model_error(config, e, logger)
 except (IOError, KeyError) as e:
 logger.error(f"{e}. Aborting.")
 sys.exit(-1)
@@ -77,17 +78,17 @@ def get_model_manager(config: Args) -> ModelManager:
 conf_path=config.conf,
 weights_directory=path,
 )
+logger.info('Model manager initialized')
 return model_manager
-def report_model_error(opt: Namespace, e: Exception):
+def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
 logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
 logger.error(
 "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
 )
 yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
 if yes_to_all:
-logger.warning
+logger.warning(
 "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
 )
 else:
@@ -103,7 +104,6 @@ def report_model_error(opt: Namespace, e: Exception):
 # only the arguments accepted by the configuration script are parsed
 root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
 config = ["--config", opt.conf] if opt.conf is not None else []
-previous_config = sys.argv
 sys.argv = ["invokeai-configure"]
 sys.argv.extend(root_dir)
 sys.argv.extend(config.to_dict())
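
Both get_model_manager() and report_model_error() now take the logger explicitly. A hedged usage sketch matching the new signatures above (the import path for get_model_manager is assumed for illustration; Args handling is abbreviated):

import invokeai.backend.util.logging as logger   # the module the callers in this commit pass
from invokeai.backend import Args
from invokeai.app.services.model_manager_initializer import get_model_manager  # assumed path

config = Args()
config.parse_args()
model_manager = get_model_manager(config, logger)   # logger is now a required argument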

View File

@@ -1,7 +1,7 @@
 import sys
 import traceback
 import torch
-import invokeai.backend.util.logging as logger
+from typing import types
 from ...backend.restoration import Restoration
 from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
@@ -11,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
 class RestorationServices:
 '''Face restoration and upscaling'''
-def __init__(self,args):
+def __init__(self,args,logger:types.ModuleType):
 try:
 gfpgan, codeformer, esrgan = None, None, None
 if args.restore or args.esrgan:
@@ -35,6 +35,8 @@ class RestorationServices:
 self.gfpgan = gfpgan
 self.codeformer = codeformer
 self.esrgan = esrgan
+self.logger = logger
+self.logger.info('Face restoration initialized')
 # note that this one method does gfpgan and codepath reconstruction, as well as
 # esrgan upscaling
@@ -59,14 +61,14 @@ class RestorationServices:
 if self.gfpgan is not None or self.codeformer is not None:
 if facetool == "gfpgan":
 if self.gfpgan is None:
-logger.info(
+self.logger.info(
 "GFPGAN not found. Face restoration is disabled."
 )
 else:
 image = self.gfpgan.process(image, strength, seed)
 if facetool == "codeformer":
 if self.codeformer is None:
-logger.info(
+self.logger.info(
 "CodeFormer not found. Face restoration is disabled."
 )
 else:
@@ -81,7 +83,7 @@ class RestorationServices:
 fidelity=codeformer_fidelity,
 )
 else:
-logger.info("Face Restoration is disabled.")
+self.logger.info("Face Restoration is disabled.")
 if upscale is not None:
 if self.esrgan is not None:
 if len(upscale) < 2:
@@ -94,9 +96,9 @@ class RestorationServices:
 denoise_str=upscale_denoise_str,
 )
 else:
-logger.info("ESRGAN is disabled. Image not upscaled.")
+self.logger.info("ESRGAN is disabled. Image not upscaled.")
 except Exception as e:
-logger.info(
+self.logger.info(
 f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
 )
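
RestorationServices likewise receives its logger in the constructor and routes every status message through it, so a different backend (for example a standard logging.Logger, or a capturing test double) can be swapped in without touching the class. A hedged sketch; the Namespace attributes are assumptions standing in for a parsed Args object:

import logging
from argparse import Namespace

quiet_args = Namespace(restore=False, esrgan=False)   # assumed minimal attributes
# restoration = RestorationServices(quiet_args, logging.getLogger("restoration-test"))
# would log "Face restoration initialized" through the injected logger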

View File

@@ -18,7 +18,7 @@ import warnings
 from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Any, Optional, Union, Callable
+from typing import Any, Optional, Union, Callable, types
 import safetensors
 import safetensors.torch
@@ -76,6 +76,8 @@ class ModelManager(object):
 Model manager handles loading, caching, importing, deleting, converting, and editing models.
 """
+logger: types.ModuleType = logger
 def __init__(
 self,
 config: OmegaConf | Path,
@@ -84,6 +86,7 @@ class ModelManager(object):
 max_loaded_models=DEFAULT_MAX_MODELS,
 sequential_offload=False,
 embedding_path: Path = None,
+logger: types.ModuleType = logger,
 ):
 """
 Initialize with the path to the models.yaml config file or
@@ -105,6 +108,7 @@ class ModelManager(object):
 self.current_model = None
 self.sequential_offload = sequential_offload
 self.embedding_path = embedding_path
+self.logger = logger
 def valid_model(self, model_name: str) -> bool:
 """
@@ -133,7 +137,7 @@ class ModelManager(object):
 )
 if not self.valid_model(model_name):
-logger.error(
+self.logger.error(
 f'"{model_name}" is not a known model name. Please check your models.yaml file'
 )
 return self.current_model
@@ -145,7 +149,7 @@ class ModelManager(object):
 if model_name in self.models:
 requested_model = self.models[model_name]["model"]
-logger.info(f"Retrieving model {model_name} from system RAM cache")
+self.logger.info(f"Retrieving model {model_name} from system RAM cache")
 requested_model.ready()
 width = self.models[model_name]["width"]
 height = self.models[model_name]["height"]
@@ -380,7 +384,7 @@ class ModelManager(object):
 """
 omega = self.config
 if model_name not in omega:
-logger.error(f"Unknown model {model_name}")
+self.logger.error(f"Unknown model {model_name}")
 return
 # save these for use in deletion later
 conf = omega[model_name]
@@ -393,13 +397,13 @@ class ModelManager(object):
 self.stack.remove(model_name)
 if delete_files:
 if weights:
-logger.info(f"Deleting file {weights}")
+self.logger.info(f"Deleting file {weights}")
 Path(weights).unlink(missing_ok=True)
 elif path:
-logger.info(f"Deleting directory {path}")
+self.logger.info(f"Deleting directory {path}")
 rmtree(path, ignore_errors=True)
 elif repo_id:
-logger.info(f"Deleting the cached model directory for {repo_id}")
+self.logger.info(f"Deleting the cached model directory for {repo_id}")
 self._delete_model_from_cache(repo_id)
 def add_model(
@@ -440,7 +444,7 @@ class ModelManager(object):
 def _load_model(self, model_name: str):
 """Load and initialize the model from configuration variables passed at object creation time"""
 if model_name not in self.config:
-logger.error(
+self.logger.error(
 f'"{model_name}" is not a known model name. Please check your models.yaml file'
 )
 return
@@ -458,7 +462,7 @@ class ModelManager(object):
 model_format = mconfig.get("format", "ckpt")
 if model_format == "ckpt":
 weights = mconfig.weights
-logger.info(f"Loading {model_name} from {weights}")
+self.logger.info(f"Loading {model_name} from {weights}")
 model, width, height, model_hash = self._load_ckpt_model(
 model_name, mconfig
 )
@@ -474,13 +478,13 @@ class ModelManager(object):
 # usage statistics
 toc = time.time()
-logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
+self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
 if self._has_cuda():
-logger.info(
+self.logger.info(
 "Max VRAM used to load the model: "+
 "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
 )
-logger.info(
+self.logger.info(
 "Current VRAM usage: "+
 "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
 )
@@ -490,11 +494,11 @@ class ModelManager(object):
 name_or_path = self.model_name_or_path(mconfig)
 using_fp16 = self.precision == "float16"
-logger.info(f"Loading diffusers model from {name_or_path}")
+self.logger.info(f"Loading diffusers model from {name_or_path}")
 if using_fp16:
-logger.debug("Using faster float16 precision")
+self.logger.debug("Using faster float16 precision")
 else:
-logger.debug("Using more accurate float32 precision")
+self.logger.debug("Using more accurate float32 precision")
 # TODO: scan weights maybe?
 pipeline_args: dict[str, Any] = dict(
@@ -526,7 +530,7 @@ class ModelManager(object):
 if str(e).startswith("fp16 is not a valid"):
 pass
 else:
-logger.error(
+self.logger.error(
 f"An unexpected error occurred while downloading the model: {e})"
 )
 if pipeline:
@@ -545,7 +549,7 @@ class ModelManager(object):
 # square images???
 width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
 height = width
-logger.debug(f"Default image dimensions = {width} x {height}")
+self.logger.debug(f"Default image dimensions = {width} x {height}")
 return pipeline, width, height, model_hash
@@ -562,7 +566,7 @@ class ModelManager(object):
 weights = os.path.normpath(os.path.join(Globals.root, weights))
 # Convert to diffusers and return a diffusers pipeline
-logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
+self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
 from . import load_pipeline_from_original_stable_diffusion_ckpt
@@ -627,7 +631,7 @@ class ModelManager(object):
 if model_name not in self.models:
 return
-logger.info(f"Offloading {model_name} to CPU")
+self.logger.info(f"Offloading {model_name} to CPU")
 model = self.models[model_name]["model"]
 model.offload_all()
 self.current_model = None
@@ -643,26 +647,26 @@ class ModelManager(object):
 and option to exit if an infected file is identified.
 """
 # scan model
-logger.debug(f"Scanning Model: {model_name}")
+self.logger.debug(f"Scanning Model: {model_name}")
 scan_result = scan_file_path(checkpoint)
 if scan_result.infected_files != 0:
 if scan_result.infected_files == 1:
-logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
+self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
-logger.critical("The model you are trying to load seems to be infected.")
+self.logger.critical("The model you are trying to load seems to be infected.")
-logger.critical("For your safety, InvokeAI will not load this model.")
+self.logger.critical("For your safety, InvokeAI will not load this model.")
-logger.critical("Please use checkpoints from trusted sources.")
+self.logger.critical("Please use checkpoints from trusted sources.")
-logger.critical("Exiting InvokeAI")
+self.logger.critical("Exiting InvokeAI")
 sys.exit()
 else:
-logger.warning("InvokeAI was unable to scan the model you are using.")
+self.logger.warning("InvokeAI was unable to scan the model you are using.")
 model_safe_check_fail = ask_user(
 "Do you want to to continue loading the model?", ["y", "n"]
 )
 if model_safe_check_fail.lower() != "y":
-logger.critical("Exiting InvokeAI")
+self.logger.critical("Exiting InvokeAI")
 sys.exit()
 else:
-logger.debug("Model scanned ok")
+self.logger.debug("Model scanned ok")
 def import_diffuser_model(
 self,
@@ -779,24 +783,24 @@ class ModelManager(object):
 model_path: Path = None
 thing = path_url_or_repo # to save typing
-logger.info(f"Probing {thing} for import")
+self.logger.info(f"Probing {thing} for import")
 if thing.startswith(("http:", "https:", "ftp:")):
-logger.info(f"{thing} appears to be a URL")
+self.logger.info(f"{thing} appears to be a URL")
 model_path = self._resolve_path(
 thing, "models/ldm/stable-diffusion-v1"
 ) # _resolve_path does a download if needed
 elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
 if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
-logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
+self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
 return
 else:
-logger.debug(f"{thing} appears to be a checkpoint file on disk")
+self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
 elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
-logger.debug(f"{thing} appears to be a diffusers file on disk")
+self.logger.debug(f"{thing} appears to be a diffusers file on disk")
 model_name = self.import_diffuser_model(
 thing,
 vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -807,30 +811,30 @@ class ModelManager(object):
 elif Path(thing).is_dir():
 if (Path(thing) / "model_index.json").exists():
-logger.debug(f"{thing} appears to be a diffusers model.")
+self.logger.debug(f"{thing} appears to be a diffusers model.")
 model_name = self.import_diffuser_model(
 thing, commit_to_conf=commit_to_conf
 )
 else:
-logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
+self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
 for m in list(Path(thing).rglob("*.ckpt")) + list(
 Path(thing).rglob("*.safetensors")
 ):
 if model_name := self.heuristic_import(
 str(m), commit_to_conf=commit_to_conf
 ):
-logger.info(f"{model_name} successfully imported")
+self.logger.info(f"{model_name} successfully imported")
 return model_name
 elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
-logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
+self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
 model_name = self.import_diffuser_model(
 thing, commit_to_conf=commit_to_conf
 )
 pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
 return model_name
 else:
-logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
+self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
 # Model_path is set in the event of a legacy checkpoint file.
 # If not set, we're all done
@@ -838,7 +842,7 @@ class ModelManager(object):
 return
 if model_path.stem in self.config: # already imported
-logger.debug("Already imported. Skipping")
+self.logger.debug("Already imported. Skipping")
 return model_path.stem
 # another round of heuristics to guess the correct config file.
@@ -854,38 +858,38 @@ class ModelManager(object):
 # look for a like-named .yaml file in same directory
 if model_path.with_suffix(".yaml").exists():
 model_config_file = model_path.with_suffix(".yaml")
-logger.debug(f"Using config file {model_config_file.name}")
+self.logger.debug(f"Using config file {model_config_file.name}")
 else:
 model_type = self.probe_model_type(checkpoint)
 if model_type == SDLegacyType.V1:
-logger.debug("SD-v1 model detected")
+self.logger.debug("SD-v1 model detected")
 model_config_file = Path(
 Globals.root, "configs/stable-diffusion/v1-inference.yaml"
 )
 elif model_type == SDLegacyType.V1_INPAINT:
-logger.debug("SD-v1 inpainting model detected")
+self.logger.debug("SD-v1 inpainting model detected")
 model_config_file = Path(
 Globals.root,
 "configs/stable-diffusion/v1-inpainting-inference.yaml",
 )
 elif model_type == SDLegacyType.V2_v:
-logger.debug("SD-v2-v model detected")
+self.logger.debug("SD-v2-v model detected")
 model_config_file = Path(
 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
 )
 elif model_type == SDLegacyType.V2_e:
-logger.debug("SD-v2-e model detected")
+self.logger.debug("SD-v2-e model detected")
 model_config_file = Path(
 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
 )
 elif model_type == SDLegacyType.V2:
-logger.warning(
+self.logger.warning(
 f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
 )
 return
 else:
-logger.warning(
+self.logger.warning(
 f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
 )
 return
@@ -902,7 +906,7 @@ class ModelManager(object):
 for suffix in ["pt", "ckpt", "safetensors"]:
 if (model_path.with_suffix(f".vae.{suffix}")).exists():
 vae_path = model_path.with_suffix(f".vae.{suffix}")
-logger.debug(f"Using VAE file {vae_path.name}")
+self.logger.debug(f"Using VAE file {vae_path.name}")
 vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
 diffuser_path = Path(
@@ -948,14 +952,14 @@ class ModelManager(object):
 from . import convert_ckpt_to_diffusers
 if diffusers_path.exists():
-logger.error(
+self.logger.error(
 f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
 )
 return
 model_name = model_name or diffusers_path.name
 model_description = model_description or f"Converted version of {model_name}"
-logger.debug(f"Converting {model_name} to diffusers (30-60s)")
+self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
 try:
 # By passing the specified VAE to the conversion function, the autoencoder
 # will be built into the model rather than tacked on afterward via the config file
@@ -972,10 +976,10 @@ class ModelManager(object):
 vae_path=vae_path,
 scan_needed=scan_needed,
 )
-logger.debug(
+self.logger.debug(
 f"Success. Converted model is now located at {str(diffusers_path)}"
 )
-logger.debug(f"Writing new config file entry for {model_name}")
+self.logger.debug(f"Writing new config file entry for {model_name}")
 new_config = dict(
 path=str(diffusers_path),
 description=model_description,
@@ -986,17 +990,17 @@ class ModelManager(object):
 self.add_model(model_name, new_config, True)
 if commit_to_conf:
 self.commit(commit_to_conf)
-logger.debug("Conversion succeeded")
+self.logger.debug("Conversion succeeded")
 except Exception as e:
-logger.warning(f"Conversion failed: {str(e)}")
+self.logger.warning(f"Conversion failed: {str(e)}")
-logger.warning(
+self.logger.warning(
 "If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
 )
 return model_name
 def search_models(self, search_folder):
-logger.info(f"Finding Models In: {search_folder}")
+self.logger.info(f"Finding Models In: {search_folder}")
 models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
 models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
@@ -1020,7 +1024,7 @@ class ModelManager(object):
 num_loaded_models = len(self.models)
 if num_loaded_models >= self.max_loaded_models:
 least_recent_model = self._pop_oldest_model()
-logger.info(
+self.logger.info(
 f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
 )
 if least_recent_model is not None:
@@ -1029,7 +1033,7 @@ class ModelManager(object):
 def print_vram_usage(self) -> None:
 if self._has_cuda:
-logger.info(
+self.logger.info(
 "Current VRAM usage:"+
 "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
 )
@@ -1119,7 +1123,7 @@ class ModelManager(object):
 dest = hub / model.stem
 if dest.exists() and not source.exists():
 continue
-logger.info(f"{source} => {dest}")
+cls.logger.info(f"{source} => {dest}")
 if source.exists():
 if dest.is_symlink():
 logger.warning(f"Found symlink at {dest.name}. Not migrating.")
@@ -1139,7 +1143,7 @@ class ModelManager(object):
 ]
 for d in empty:
 os.rmdir(d)
-logger.info("Migration is done. Continuing...")
+cls.logger.info("Migration is done. Continuing...")
 def _resolve_path(
 self, source: Union[str, Path], dest_directory: str
@@ -1182,14 +1186,14 @@ class ModelManager(object):
 def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
 if self.embedding_path is not None:
-logger.info(f"Loading embeddings from {self.embedding_path}")
+self.logger.info(f"Loading embeddings from {self.embedding_path}")
 for root, _, files in os.walk(self.embedding_path):
 for name in files:
 ti_path = os.path.join(root, name)
 model.textual_inversion_manager.load_textual_inversion(
 ti_path, defer_injecting_tokens=True
 )
-logger.info(
+self.logger.info(
 f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
 )
@@ -1212,7 +1216,7 @@ class ModelManager(object):
 with open(hashpath) as f:
 hash = f.read()
 return hash
-logger.debug("Calculating sha256 hash of model files")
+self.logger.debug("Calculating sha256 hash of model files")
 tic = time.time()
 sha = hashlib.sha256()
 count = 0
@@ -1224,7 +1228,7 @@ class ModelManager(object):
 sha.update(chunk)
 hash = sha.hexdigest()
 toc = time.time()
-logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
+self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
 with open(hashpath, "w") as f:
 f.write(hash)
 return hash
@@ -1242,13 +1246,13 @@ class ModelManager(object):
 hash = f.read()
 return hash
-logger.debug("Calculating sha256 hash of weights file")
+self.logger.debug("Calculating sha256 hash of weights file")
 tic = time.time()
 sha = hashlib.sha256()
 sha.update(data)
 hash = sha.hexdigest()
 toc = time.time()
-logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
+self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
 with open(hashpath, "w") as f:
 f.write(hash)
@@ -1269,12 +1273,12 @@ class ModelManager(object):
 local_files_only=not Globals.internet_available,
 )
-logger.debug(f"Loading diffusers VAE from {name_or_path}")
+self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
 if using_fp16:
 vae_args.update(torch_dtype=torch.float16)
 fp_args_list = [{"revision": "fp16"}, {}]
 else:
-logger.debug("Using more accurate float32 precision")
+self.logger.debug("Using more accurate float32 precision")
 fp_args_list = [{}]
 vae = None
@@ -1298,12 +1302,12 @@ class ModelManager(object):
 break
 if not vae and deferred_error:
-logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
+self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
 return vae
-@staticmethod
+@classmethod
-def _delete_model_from_cache(repo_id):
+def _delete_model_from_cache(cls,repo_id):
 cache_info = scan_cache_dir(global_cache_dir("hub"))
 # I'm sure there is a way to do this with comprehensions
@@ -1314,7 +1318,7 @@ class ModelManager(object):
 for revision in repo.revisions:
 hashes_to_delete.add(revision.commit_hash)
 strategy = cache_info.delete_revisions(*hashes_to_delete)
-logger.warning(
+cls.logger.warning(
 f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
 )
 strategy.execute()
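
ModelManager now carries the logger as both a class-level default and an instance attribute, which is what lets helpers such as choose_model() above read it back via model_manager.logger. A hedged construction sketch (any constructor arguments not shown are assumed to keep their defaults; the config path is illustrative):

import invokeai.backend.util.logging as logger
from invokeai.backend import ModelManager
from omegaconf import OmegaConf

manager = ModelManager(OmegaConf.load("configs/models.yaml"), logger=logger)
manager.logger.info("model manager ready")   # the same object that was injected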