mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
prompt user for prediction type when autoimporting a v2 model without .yaml file
don't ask user for prediction type of a config.yaml provided
This commit is contained in:
parent
f67dec7f0c
commit
823e098b7c
@ -411,7 +411,6 @@ def update_autoimport_dir(autodir: Path):
|
||||
outfile.write(yaml)
|
||||
tmpfile.replace(invokeai_config_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
|
@ -712,8 +712,12 @@ class ModelManager(object):
|
||||
'''
|
||||
# avoid circular import
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
model_manager = self)
|
||||
model_manager = self,
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
installed = set()
|
||||
if not self.app_config.autoimport_dir:
|
||||
|
@ -255,7 +255,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return SchedulerPredictionType.VPrediction
|
||||
if self.checkpoint_path and self.helper:
|
||||
if self.checkpoint_path and self.helper \
|
||||
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
|
||||
return self.helper(self.checkpoint_path)
|
||||
else:
|
||||
return None
|
||||
|
@ -578,14 +578,14 @@ class StderrToMessage():
|
||||
# --------------------------------------------------------
|
||||
def ask_user_for_prediction_type(model_path: Path,
|
||||
tui_conn: Connection=None
|
||||
)->Path:
|
||||
)->SchedulerPredictionType:
|
||||
if tui_conn:
|
||||
logger.debug('Waiting for user response...')
|
||||
return _ask_user_for_pt_tui(model_path, tui_conn)
|
||||
else:
|
||||
return _ask_user_for_pt_cmdline(model_path)
|
||||
|
||||
def _ask_user_for_pt_cmdline(model_path):
|
||||
def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
|
||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||
print(
|
||||
f"""
|
||||
@ -608,7 +608,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
|
||||
return
|
||||
return choice
|
||||
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
|
||||
try:
|
||||
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
|
||||
# note that we don't do any status checking here
|
||||
|
Loading…
Reference in New Issue
Block a user