adjust non-app modules to use new config system

This commit is contained in:
Lincoln Stein
2023-05-04 00:43:51 -04:00
parent 15ffb53e59
commit e4196bbe5b
18 changed files with 84 additions and 98 deletions

View File

@ -8,7 +8,6 @@ import argparse
import curses
import os
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
@ -20,20 +19,13 @@ from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf
from ...backend.globals import (
Globals,
global_cache_dir,
global_config_file,
global_models_dir,
global_set_root,
)
import invokeai.backend.util.logging as logger
from invokeai.services.config import get_invokeai_config
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
DEST_MERGED_MODEL_DIR = "merged_models"
config = get_invokeai_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
@ -60,7 +52,7 @@ def merge_diffusion_models(
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
@ -94,7 +86,7 @@ def merge_diffusion_models_and_commit(
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = global_config_file()
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
@ -106,7 +98,7 @@ def merge_diffusion_models_and_commit(
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
@ -126,7 +118,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--root_dir",
type=Path,
default=Globals.root,
default=config.root,
help="Path to the invokeai runtime directory",
)
parser.add_argument(
@ -398,7 +390,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
conf = OmegaConf.load(global_config_file())
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
@ -429,7 +421,7 @@ def run_cli(args: Namespace):
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(global_config_file()))
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
@ -440,9 +432,9 @@ def run_cli(args: Namespace):
def main():
args = _parse_args()
global_set_root(args.root_dir)
config.root = args.root_dir
cache_dir = str(global_cache_dir("hub"))
cache_dir = config.cache_dir
os.environ[
"HF_HOME"
] = cache_dir # because not clear the merge pipeline is honoring cache_dir