prompt user for prediction type when autoimporting a v2 model without a .yaml file

don't ask user for prediction type if a config.yaml is provided
Lincoln Stein 2023-06-26 16:18:16 -04:00
parent f67dec7f0c
commit 823e098b7c
4 changed files with 10 additions and 6 deletions
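
In outline, the change makes the prediction-type prompt conditional on a missing sidecar config: a v2 checkpoint imported with a .yaml file beside it is handled silently, while one without it triggers the question. A minimal sketch of that decision, assuming a hypothetical resolve_prediction_type helper and a stand-in enum in place of InvokeAI's real SchedulerPredictionType:

# Sketch only: a stand-in enum is defined here so the snippet is self-contained;
# the real SchedulerPredictionType lives in InvokeAI's model-management backend.
from enum import Enum
from pathlib import Path
from typing import Callable, Optional

class SchedulerPredictionType(str, Enum):
    Epsilon = "epsilon"
    VPrediction = "v_prediction"

def resolve_prediction_type(
    checkpoint_path: Path,
    helper: Optional[Callable[[Path], SchedulerPredictionType]],
) -> Optional[SchedulerPredictionType]:
    """Decide whether the user must be asked for a v2 prediction type."""
    # A sibling .yaml legacy config already pins the prediction type,
    # so there is nothing to ask about.
    if checkpoint_path.with_suffix(".yaml").exists():
        return None
    # Otherwise defer to the interactive prompt, if one was wired in.
    if helper is not None:
        return helper(checkpoint_path)
    return None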

View File

@@ -411,7 +411,6 @@ def update_autoimport_dir(autodir: Path):
outfile.write(yaml)
tmpfile.replace(invokeai_config_path)
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"

View File

@@ -712,8 +712,12 @@ class ModelManager(object):
'''
# avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
installer = ModelInstall(config = self.app_config,
model_manager = self)
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
installed = set()
if not self.app_config.autoimport_dir:
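
The hunk above shows only the call site; ModelInstall's internals are not part of this commit. The pattern being wired up is an installer that stores an optional prompt callback and defers calling it until a checkpoint actually needs a prediction type. An illustrative sketch with hypothetical names (InstallerSketch, install_checkpoint):

# Illustrative only: ModelInstall's real interface is not shown in this diff.
from pathlib import Path
from typing import Callable, Optional

class InstallerSketch:
    def __init__(self, prediction_type_helper: Optional[Callable[[Path], str]] = None):
        # Stored, not called: the prompt fires later, and only for v2
        # checkpoints that lack a sidecar .yaml config.
        self.prediction_type_helper = prediction_type_helper

    def install_checkpoint(self, path: Path) -> None:
        prediction_type = None
        if self.prediction_type_helper and not path.with_suffix(".yaml").exists():
            prediction_type = self.prediction_type_helper(path)
        print(f"installing {path.name} (prediction_type={prediction_type})")

# Usage mirroring the call site above, with a canned answer standing in for
# the interactive ask_user_for_prediction_type helper.
installer = InstallerSketch(prediction_type_helper=lambda p: "v_prediction")
installer.install_checkpoint(Path("some-v2-model.ckpt"))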

View File

@@ -255,7 +255,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper:
if self.checkpoint_path and self.helper \
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
return self.helper(self.checkpoint_path)
else:
return None
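
The reworked condition above is the core of the fix: the probe still tries the global_step heuristics first, but it now consults the interactive helper only when no sibling .yaml config exists. A legacy .yaml already declares the prediction type, so in that case the probe returns None and the config governs, which matches the "don't ask" behaviour described in the commit message.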

View File

@@ -578,14 +578,14 @@ class StderrToMessage():
# --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None
)->Path:
)->SchedulerPredictionType:
if tui_conn:
logger.debug('Waiting for user response...')
return _ask_user_for_pt_tui(model_path, tui_conn)
else:
return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_pt_cmdline(model_path):
def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
print(
f"""
@@ -608,7 +608,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return
return choice
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here
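
Only the request side of the TUI hand-off appears in this hunk. A sketch of the round trip, assuming the front end replies over the same multiprocessing Connection with a short token naming the chosen type; the reply convention is an assumption, only the '*need v2 config for:' request format is shown above:

# Request/response sketch. Only the request prefix comes from the diff above;
# the reply convention (a utf-8 token such as "epsilon" or "v") is assumed.
from multiprocessing.connection import Connection
from pathlib import Path

def ask_tui_for_prediction_type(model_path: Path, tui_conn: Connection) -> str:
    # Mirrors the request side of _ask_user_for_pt_tui (hypothetical helper name).
    tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
    # Block until the TUI replies, then hand the raw token back; the real code
    # would map it onto SchedulerPredictionType.
    return tui_conn.recv_bytes().decode("utf-8").strip()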