mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
logger is a interchangeable service
This commit is contained in:
parent
8db20e0d95
commit
974841926d
@ -67,8 +67,9 @@ class ApiDependencies:
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config),
|
||||
model_manager=get_model_manager(config,logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
@ -80,7 +81,7 @@ class ApiDependencies:
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
@ -4,15 +4,10 @@ import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -113,6 +108,7 @@ async def update_model(
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
|
@ -181,7 +181,7 @@ def invoke_all(context: CliContext):
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
logger.error(
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
@ -191,7 +191,7 @@ def invoke_all(context: CliContext):
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config)
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
@ -224,7 +224,8 @@ def invoke_cli():
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@ -364,12 +365,12 @@ def invoke_cli():
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
logger.warning('Invalid command, use "help" to list commands')
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
logger.warning("Session error: creating a new session")
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
|
@ -1,9 +1,9 @@
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import types
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
@ -29,6 +31,7 @@ class InvocationServices:
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
@ -40,6 +43,7 @@ class InvocationServices:
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
|
@ -5,20 +5,20 @@ from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
|
||||
)
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
@ -63,9 +63,10 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e)
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
@ -77,17 +78,17 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
logger.info('Model manager initialized')
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
logger.warning
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
@ -103,7 +104,6 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_config = sys.argv
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
|
@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@ -11,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args):
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@ -35,6 +35,8 @@ class RestorationServices:
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@ -59,14 +61,14 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
@ -81,7 +83,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
logger.info("Face Restoration is disabled.")
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@ -94,9 +96,9 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,7 @@ import warnings
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
from typing import Any, Optional, Union, Callable, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
@ -76,6 +76,8 @@ class ModelManager(object):
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
"""
|
||||
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
@ -84,6 +86,7 @@ class ModelManager(object):
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
@ -105,6 +108,7 @@ class ModelManager(object):
|
||||
self.current_model = None
|
||||
self.sequential_offload = sequential_offload
|
||||
self.embedding_path = embedding_path
|
||||
self.logger = logger
|
||||
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
@ -133,7 +137,7 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
logger.error(
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return self.current_model
|
||||
@ -145,7 +149,7 @@ class ModelManager(object):
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]["model"]
|
||||
logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||
requested_model.ready()
|
||||
width = self.models[model_name]["width"]
|
||||
height = self.models[model_name]["height"]
|
||||
@ -380,7 +384,7 @@ class ModelManager(object):
|
||||
"""
|
||||
omega = self.config
|
||||
if model_name not in omega:
|
||||
logger.error(f"Unknown model {model_name}")
|
||||
self.logger.error(f"Unknown model {model_name}")
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
@ -393,13 +397,13 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
logger.info(f"Deleting file {weights}")
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
logger.info(f"Deleting directory {path}")
|
||||
self.logger.info(f"Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
@ -440,7 +444,7 @@ class ModelManager(object):
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
logger.error(
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return
|
||||
@ -458,7 +462,7 @@ class ModelManager(object):
|
||||
model_format = mconfig.get("format", "ckpt")
|
||||
if model_format == "ckpt":
|
||||
weights = mconfig.weights
|
||||
logger.info(f"Loading {model_name} from {weights}")
|
||||
self.logger.info(f"Loading {model_name} from {weights}")
|
||||
model, width, height, model_hash = self._load_ckpt_model(
|
||||
model_name, mconfig
|
||||
)
|
||||
@ -474,13 +478,13 @@ class ModelManager(object):
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
if self._has_cuda():
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
"Max VRAM used to load the model: "+
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||
)
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
"Current VRAM usage: "+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||
)
|
||||
@ -490,11 +494,11 @@ class ModelManager(object):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
logger.debug("Using faster float16 precision")
|
||||
self.logger.debug("Using faster float16 precision")
|
||||
else:
|
||||
logger.debug("Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
@ -526,7 +530,7 @@ class ModelManager(object):
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
logger.error(
|
||||
self.logger.error(
|
||||
f"An unexpected error occurred while downloading the model: {e})"
|
||||
)
|
||||
if pipeline:
|
||||
@ -545,7 +549,7 @@ class ModelManager(object):
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
@ -562,7 +566,7 @@ class ModelManager(object):
|
||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||
|
||||
# Convert to diffusers and return a diffusers pipeline
|
||||
logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
@ -627,7 +631,7 @@ class ModelManager(object):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
logger.info(f"Offloading {model_name} to CPU")
|
||||
self.logger.info(f"Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
@ -643,26 +647,26 @@ class ModelManager(object):
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
logger.debug(f"Scanning Model: {model_name}")
|
||||
self.logger.debug(f"Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
logger.critical("The model you are trying to load seems to be infected.")
|
||||
logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
logger.critical("Please use checkpoints from trusted sources.")
|
||||
logger.critical("Exiting InvokeAI")
|
||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
model_safe_check_fail = ask_user(
|
||||
"Do you want to to continue loading the model?", ["y", "n"]
|
||||
)
|
||||
if model_safe_check_fail.lower() != "y":
|
||||
logger.critical("Exiting InvokeAI")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
logger.debug("Model scanned ok")
|
||||
self.logger.debug("Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@ -779,24 +783,24 @@ class ModelManager(object):
|
||||
model_path: Path = None
|
||||
thing = path_url_or_repo # to save typing
|
||||
|
||||
logger.info(f"Probing {thing} for import")
|
||||
self.logger.info(f"Probing {thing} for import")
|
||||
|
||||
if thing.startswith(("http:", "https:", "ftp:")):
|
||||
logger.info(f"{thing} appears to be a URL")
|
||||
self.logger.info(f"{thing} appears to be a URL")
|
||||
model_path = self._resolve_path(
|
||||
thing, "models/ldm/stable-diffusion-v1"
|
||||
) # _resolve_path does a download if needed
|
||||
|
||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||
logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||
return
|
||||
else:
|
||||
logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||
logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
@ -807,30 +811,30 @@ class ModelManager(object):
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
if (Path(thing) / "model_index.json").exists():
|
||||
logger.debug(f"{thing} appears to be a diffusers model.")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers model.")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
else:
|
||||
logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||
Path(thing).rglob("*.safetensors")
|
||||
):
|
||||
if model_name := self.heuristic_import(
|
||||
str(m), commit_to_conf=commit_to_conf
|
||||
):
|
||||
logger.info(f"{model_name} successfully imported")
|
||||
self.logger.info(f"{model_name} successfully imported")
|
||||
return model_name
|
||||
|
||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||
logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||
return model_name
|
||||
else:
|
||||
logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||
|
||||
# Model_path is set in the event of a legacy checkpoint file.
|
||||
# If not set, we're all done
|
||||
@ -838,7 +842,7 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: # already imported
|
||||
logger.debug("Already imported. Skipping")
|
||||
self.logger.debug("Already imported. Skipping")
|
||||
return model_path.stem
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
@ -854,38 +858,38 @@ class ModelManager(object):
|
||||
# look for a like-named .yaml file in same directory
|
||||
if model_path.with_suffix(".yaml").exists():
|
||||
model_config_file = model_path.with_suffix(".yaml")
|
||||
logger.debug(f"Using config file {model_config_file.name}")
|
||||
self.logger.debug(f"Using config file {model_config_file.name}")
|
||||
|
||||
else:
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
logger.debug("SD-v1 model detected")
|
||||
self.logger.debug("SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
logger.debug("SD-v1 inpainting model detected")
|
||||
self.logger.debug("SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
logger.debug("SD-v2-v model detected")
|
||||
self.logger.debug("SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
logger.debug("SD-v2-e model detected")
|
||||
self.logger.debug("SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2:
|
||||
logger.warning(
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
self.logger.warning(
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
@ -902,7 +906,7 @@ class ModelManager(object):
|
||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||
logger.debug(f"Using VAE file {vae_path.name}")
|
||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
|
||||
diffuser_path = Path(
|
||||
@ -948,14 +952,14 @@ class ModelManager(object):
|
||||
from . import convert_ckpt_to_diffusers
|
||||
|
||||
if diffusers_path.exists():
|
||||
logger.error(
|
||||
self.logger.error(
|
||||
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
)
|
||||
return
|
||||
|
||||
model_name = model_name or diffusers_path.name
|
||||
model_description = model_description or f"Converted version of {model_name}"
|
||||
logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
@ -972,10 +976,10 @@ class ModelManager(object):
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
logger.debug(
|
||||
self.logger.debug(
|
||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
logger.debug(f"Writing new config file entry for {model_name}")
|
||||
self.logger.debug(f"Writing new config file entry for {model_name}")
|
||||
new_config = dict(
|
||||
path=str(diffusers_path),
|
||||
description=model_description,
|
||||
@ -986,17 +990,17 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
logger.debug("Conversion succeeded")
|
||||
self.logger.debug("Conversion succeeded")
|
||||
except Exception as e:
|
||||
logger.warning(f"Conversion failed: {str(e)}")
|
||||
logger.warning(
|
||||
self.logger.warning(f"Conversion failed: {str(e)}")
|
||||
self.logger.warning(
|
||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
def search_models(self, search_folder):
|
||||
logger.info(f"Finding Models In: {search_folder}")
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||
|
||||
@ -1020,7 +1024,7 @@ class ModelManager(object):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
)
|
||||
if least_recent_model is not None:
|
||||
@ -1029,7 +1033,7 @@ class ModelManager(object):
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
"Current VRAM usage:"+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
@ -1119,7 +1123,7 @@ class ModelManager(object):
|
||||
dest = hub / model.stem
|
||||
if dest.exists() and not source.exists():
|
||||
continue
|
||||
logger.info(f"{source} => {dest}")
|
||||
cls.logger.info(f"{source} => {dest}")
|
||||
if source.exists():
|
||||
if dest.is_symlink():
|
||||
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||
@ -1139,7 +1143,7 @@ class ModelManager(object):
|
||||
]
|
||||
for d in empty:
|
||||
os.rmdir(d)
|
||||
logger.info("Migration is done. Continuing...")
|
||||
cls.logger.info("Migration is done. Continuing...")
|
||||
|
||||
def _resolve_path(
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
@ -1182,14 +1186,14 @@ class ModelManager(object):
|
||||
|
||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||
if self.embedding_path is not None:
|
||||
logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||
self.logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
model.textual_inversion_manager.load_textual_inversion(
|
||||
ti_path, defer_injecting_tokens=True
|
||||
)
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
@ -1212,7 +1216,7 @@ class ModelManager(object):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
logger.debug("Calculating sha256 hash of model files")
|
||||
self.logger.debug("Calculating sha256 hash of model files")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
count = 0
|
||||
@ -1224,7 +1228,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@ -1242,13 +1246,13 @@ class ModelManager(object):
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
logger.debug("Calculating sha256 hash of weights file")
|
||||
self.logger.debug("Calculating sha256 hash of weights file")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
@ -1269,12 +1273,12 @@ class ModelManager(object):
|
||||
local_files_only=not Globals.internet_available,
|
||||
)
|
||||
|
||||
logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||
if using_fp16:
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
logger.debug("Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
fp_args_list = [{}]
|
||||
|
||||
vae = None
|
||||
@ -1298,12 +1302,12 @@ class ModelManager(object):
|
||||
break
|
||||
|
||||
if not vae and deferred_error:
|
||||
logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
@classmethod
|
||||
def _delete_model_from_cache(cls,repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
@ -1314,7 +1318,7 @@ class ModelManager(object):
|
||||
for revision in repo.revisions:
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
logger.warning(
|
||||
cls.logger.warning(
|
||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
)
|
||||
strategy.execute()
|
||||
|
Loading…
Reference in New Issue
Block a user