prompt user for prediction type when autoimporting a v2 model without .yaml file

don't ask user for prediction type of a config.yaml provided
This commit is contained in:
Lincoln Stein
2023-06-26 16:18:16 -04:00
parent f67dec7f0c
commit 823e098b7c
4 changed files with 10 additions and 6 deletions

View File

@ -578,14 +578,14 @@ class StderrToMessage():
# --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None
)->Path:
)->SchedulerPredictionType:
if tui_conn:
logger.debug('Waiting for user response...')
return _ask_user_for_pt_tui(model_path, tui_conn)
else:
return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_pt_cmdline(model_path):
def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
print(
f"""
@ -608,7 +608,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return
return choice
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here