diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index ff8198103b..e37184a77b 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -22,6 +22,10 @@ class LogLevel(int, Enum): Warning = logging.WARNING Error = logging.ERROR Critical = logging.CRITICAL + +class Upscaler(BaseModel): + upscaling_method: str = Field(description="Name of upscaling method") + upscaling_models: list[str] = Field(description="List of upscaling models for this method") app_router = APIRouter(prefix="/v1/app", tags=["app"]) @@ -36,8 +40,7 @@ class AppConfig(BaseModel): """App Config Response""" infill_methods: list[str] = Field(description="List of available infill methods") - upscaling_methods: list[str] = Field(description="List of upscaling methods") - upscaling_models: list[str] = Field(description="List of postprocessing methods") + upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods") nsfw_methods: list[str] = Field(description="List of NSFW checking methods") watermarking_methods: list[str] = Field(description="List of invisible watermark methods") @@ -57,11 +60,15 @@ async def get_config() -> AppConfig: if PatchMatch.patchmatch_available(): infill_methods.append('patchmatch') - upscaling_methods = ['esrgan'] + upscaling_models = [] for model in typing.get_args(ESRGAN_MODELS): upscaling_models.append(str(Path(model).stem)) - + upscaler = Upscaler( + upscaling_method = 'esrgan', + upscaling_models = upscaling_models + ) + nsfw_methods = [] if SafetyChecker.safety_checker_available(): nsfw_methods.append('nsfw_checker') @@ -72,8 +79,7 @@ async def get_config() -> AppConfig: return AppConfig( infill_methods=infill_methods, - upscaling_methods=upscaling_methods, - upscaling_models=upscaling_models, + upscaling_methods=[upscaler], nsfw_methods=nsfw_methods, watermarking_methods=watermarking_methods, )