esrgan and its models are now nested in app config route

This commit is contained in:
Lincoln Stein 2023-07-24 22:17:22 -04:00
parent efa615a8fd
commit 91e903c8ab

View File

@ -23,6 +23,10 @@ class LogLevel(int, Enum):
Error = logging.ERROR
Critical = logging.CRITICAL
class Upscaler(BaseModel):
upscaling_method: str = Field(description="Name of upscaling method")
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
app_router = APIRouter(prefix="/v1/app", tags=["app"])
@ -36,8 +40,7 @@ class AppConfig(BaseModel):
"""App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[str] = Field(description="List of upscaling methods")
upscaling_models: list[str] = Field(description="List of postprocessing methods")
upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@ -57,10 +60,14 @@ async def get_config() -> AppConfig:
if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch')
upscaling_methods = ['esrgan']
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(
upscaling_method = 'esrgan',
upscaling_models = upscaling_models
)
nsfw_methods = []
if SafetyChecker.safety_checker_available():
@ -72,8 +79,7 @@ async def get_config() -> AppConfig:
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=upscaling_methods,
upscaling_models=upscaling_models,
upscaling_methods=[upscaler],
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)