prompt user for prediction type when autoimporting a v2 model without .yaml file

don't ask user for prediction type of a config.yaml provided
This commit is contained in:
Lincoln Stein 2023-06-26 16:18:16 -04:00
parent f67dec7f0c
commit 823e098b7c
4 changed files with 10 additions and 6 deletions

View File

@ -411,7 +411,6 @@ def update_autoimport_dir(autodir: Path):
outfile.write(yaml) outfile.write(yaml)
tmpfile.replace(invokeai_config_path) tmpfile.replace(invokeai_config_path)
# ------------------------------------- # -------------------------------------
def yes_or_no(prompt: str, default_yes=True): def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n" default = "y" if default_yes else "n"

View File

@ -712,8 +712,12 @@ class ModelManager(object):
''' '''
# avoid circular import # avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
model_manager = self) model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
installed = set() installed = set()
if not self.app_config.autoimport_dir: if not self.app_config.autoimport_dir:

View File

@ -255,7 +255,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return SchedulerPredictionType.Epsilon return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000: elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper: if self.checkpoint_path and self.helper \
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
return self.helper(self.checkpoint_path) return self.helper(self.checkpoint_path)
else: else:
return None return None

View File

@ -578,14 +578,14 @@ class StderrToMessage():
# -------------------------------------------------------- # --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path, def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None tui_conn: Connection=None
)->Path: )->SchedulerPredictionType:
if tui_conn: if tui_conn:
logger.debug('Waiting for user response...') logger.debug('Waiting for user response...')
return _ask_user_for_pt_tui(model_path, tui_conn) return _ask_user_for_pt_tui(model_path, tui_conn)
else: else:
return _ask_user_for_pt_cmdline(model_path) return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_pt_cmdline(model_path): def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
print( print(
f""" f"""
@ -608,7 +608,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return return
return choice return choice
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path: def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
try: try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here # note that we don't do any status checking here