refactored code; added watermark and nsfw facilities to app config route

This commit is contained in:
Lincoln Stein
2023-07-24 22:02:57 -04:00
parent 4194a0ed99
commit efa615a8fd
7 changed files with 139 additions and 138 deletions

View File

@ -1,9 +1,15 @@
import typing
from enum import Enum
from fastapi import Body
from fastapi.routing import APIRouter
from pathlib import Path
from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.version import __version__
from ..dependencies import ApiDependencies
@ -30,6 +36,10 @@ class AppConfig(BaseModel):
"""App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[str] = Field(description="List of upscaling methods")
upscaling_models: list[str] = Field(description="List of postprocessing methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@app_router.get(
@ -46,7 +56,27 @@ async def get_config() -> AppConfig:
infill_methods = ['tile']
if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)
upscaling_methods = ['esrgan']
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append('nsfw_checker')
watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append('invisible_watermark')
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=upscaling_methods,
upscaling_models=upscaling_models,
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)
@app_router.get(
"/logging",