From 823e098b7c06c026a1c692c2eb206559dd85a55b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 26 Jun 2023 16:18:16 -0400 Subject: [PATCH] prompt user for prediction type when autoimporting a v2 model without .yaml file don't ask user for prediction type of a config.yaml provided --- invokeai/backend/install/model_install_backend.py | 1 - invokeai/backend/model_management/model_manager.py | 6 +++++- invokeai/backend/model_management/model_probe.py | 3 ++- invokeai/frontend/install/model_install.py | 6 +++--- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index dcc0eac902..ac25316d9e 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -411,7 +411,6 @@ def update_autoimport_dir(autodir: Path): outfile.write(yaml) tmpfile.replace(invokeai_config_path) - # ------------------------------------- def yes_or_no(prompt: str, default_yes=True): default = "y" if default_yes else "n" diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index c0d5122886..b88550d63b 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -712,8 +712,12 @@ class ModelManager(object): ''' # avoid circular import from invokeai.backend.install.model_install_backend import ModelInstall + from invokeai.frontend.install.model_install import ask_user_for_prediction_type + installer = ModelInstall(config = self.app_config, - model_manager = self) + model_manager = self, + prediction_type_helper = ask_user_for_prediction_type, + ) installed = set() if not self.app_config.autoimport_dir: diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 2b6eb7e7be..42f4bb6225 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -255,7 +255,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase): return SchedulerPredictionType.Epsilon elif checkpoint["global_step"] == 110000: return SchedulerPredictionType.VPrediction - if self.checkpoint_path and self.helper: + if self.checkpoint_path and self.helper \ + and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed return self.helper(self.checkpoint_path) else: return None diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 183be03173..900426eac6 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -578,14 +578,14 @@ class StderrToMessage(): # -------------------------------------------------------- def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection=None - )->Path: + )->SchedulerPredictionType: if tui_conn: logger.debug('Waiting for user response...') return _ask_user_for_pt_tui(model_path, tui_conn) else: return _ask_user_for_pt_cmdline(model_path) -def _ask_user_for_pt_cmdline(model_path): +def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType: choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] print( f""" @@ -608,7 +608,7 @@ Please select the type of the V2 checkpoint named {model_path.name}: return return choice -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path: +def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType: try: tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) # note that we don't do any status checking here