mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
closes #4768
This commit is contained in:
@ -20,6 +20,7 @@ from multiprocessing import Process
|
||||
from multiprocessing.connection import Connection, Pipe
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
@ -630,21 +631,23 @@ def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None)
|
||||
return _ask_user_for_pt_cmdline(model_path)
|
||||
|
||||
|
||||
def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType:
|
||||
def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]:
|
||||
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||
print(
|
||||
f"""
|
||||
Please select the type of the V2 checkpoint named {model_path.name}:
|
||||
[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
|
||||
[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
|
||||
[3] Skip this model and come back later.
|
||||
Please select the scheduler prediction type of the checkpoint named {model_path.name}:
|
||||
[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
|
||||
[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
|
||||
[3] Accept the best guess; you can fix it in the Web UI later
|
||||
"""
|
||||
)
|
||||
choice = None
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input("select> ").strip()
|
||||
choice = input("select [3]> ").strip()
|
||||
if not choice:
|
||||
return None
|
||||
choice = choices[int(choice) - 1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
@ -655,22 +658,18 @@ Please select the type of the V2 checkpoint named {model_path.name}:
|
||||
|
||||
|
||||
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
|
||||
try:
|
||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode("utf-8")
|
||||
if response is None:
|
||||
return None
|
||||
elif response == "epsilon":
|
||||
return SchedulerPredictionType.epsilon
|
||||
elif response == "v":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif response == "abort":
|
||||
logger.info("Conversion aborted")
|
||||
return None
|
||||
else:
|
||||
return response
|
||||
except Exception:
|
||||
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
|
||||
# note that we don't do any status checking here
|
||||
response = tui_conn.recv_bytes().decode("utf-8")
|
||||
if response is None:
|
||||
return None
|
||||
elif response == "epsilon":
|
||||
return SchedulerPredictionType.epsilon
|
||||
elif response == "v":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif response == "guess":
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user