ask user for v2 variant when model manager can't infer it

This commit is contained in:
Lincoln Stein
2023-06-04 11:27:44 -04:00
parent 31e97ead2a
commit 1a7fb601dc
5 changed files with 225 additions and 48 deletions

View File

@ -9,7 +9,7 @@ import warnings
from dataclasses import dataclass,field
from pathlib import Path
from tempfile import TemporaryFile
from typing import List, Dict
from typing import List, Dict, Callable
import requests
from diffusers import AutoencoderKL
@ -95,6 +95,7 @@ def install_requested_models(
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
):
"""
Entry point for installing/deleting starter models, or installing external models.
@ -118,19 +119,19 @@ def install_requested_models(
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("DELETING UNCHECKED STARTER MODELS")
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("INSTALLING SELECTED STARTER MODELS")
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
@ -153,6 +154,7 @@ def install_requested_models(
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
config_file_callback = model_config_file_callback,
)
except KeyboardInterrupt:
sys.exit(-1)

View File

@ -874,14 +874,12 @@ class ModelManager(object):
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
)
return
else:
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
)
return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)