Adds Restoration module, variations fix, cli args

This commit is contained in:
psychedelicious 2022-09-22 00:32:26 +10:00 committed by Lincoln Stein
parent 8c751d342d
commit c0aa92ea13
8 changed files with 132 additions and 46 deletions

View File

@ -0,0 +1,49 @@
import argparse
import os
from ldm.dream.args import PRECISION_CHOICES
def create_cmd_parser():
parser = argparse.ArgumentParser(description="InvokeAI web UI")
parser.add_argument(
"--host",
type=str,
help="The host to serve on",
default="localhost",
)
parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
parser.add_argument(
"--cors",
nargs="*",
type=str,
help="Additional allowed origins, comma-separated",
)
parser.add_argument(
"--embedding_path",
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
)
# TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
# parser.add_argument(
# "--output_dir",
# default="outputs/",
# type=str,
# help="Directory for output images",
# )
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Enables verbose logging",
)
parser.add_argument(
"--precision",
dest="precision",
type=str,
choices=PRECISION_CHOICES,
metavar="PRECISION",
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default="auto",
)
return parser

View File

@ -8,6 +8,16 @@ import glob
import shlex import shlex
import math import math
import shutil import shutil
import sys
sys.path.append(".")
from argparse import ArgumentTypeError
from modules.create_cmd_parser import create_cmd_parser
parser = create_cmd_parser()
opt = parser.parse_args()
from flask_socketio import SocketIO from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify from flask import Flask, send_from_directory, url_for, jsonify
@ -18,9 +28,9 @@ from threading import Event
from uuid import uuid4 from uuid import uuid4
from send2trash import send2trash from send2trash import send2trash
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate from ldm.generate import Generate
from ldm.dream.restoration import Restoration
from ldm.dream.pngwriter import PngWriter, retrieve_metadata from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts from ldm.dream.conditioning import split_weighted_subprompts
@ -31,15 +41,19 @@ from modules.parameters import parameters_to_command
""" """
USER CONFIG USER CONFIG
""" """
if opt.cors and "*" in opt.cors:
raise ArgumentTypeError('"*" is not an allowed CORS origin')
output_dir = "outputs/" # Base output directory for images output_dir = "outputs/" # Base output directory for images
# host = 'localhost' # Web & socket.io host host = opt.host # Web & socket.io host
host = "localhost" # Web & socket.io host port = opt.port # Web & socket.io port
port = 9090 # Web & socket.io port verbose = opt.verbose # enables copious socket.io logging
verbose = False # enables copious socket.io logging precision = opt.precision
additional_allowed_origins = [ embedding_path = opt.embedding_path
"http://localhost:5173" additional_allowed_origins = (
] # additional CORS allowed origins opt.cors if opt.cors else []
) # additional CORS allowed origins
model = "stable-diffusion-1.4" model = "stable-diffusion-1.4"
""" """
@ -47,6 +61,9 @@ END USER CONFIG
""" """
print("* Initializing, be patient...\n")
""" """
SERVER SETUP SERVER SETUP
""" """
@ -103,6 +120,20 @@ class CanceledException(Exception):
pass pass
try:
gfpgan, codeformer, esrgan = None, None, None
from ldm.dream.restoration.base import Restoration
restoration = Restoration()
gfpgan, codeformer = restoration.load_face_restore_models()
esrgan = restoration.load_esrgan()
# coreformer.process(self, image, strength, device, seed=None, fidelity=0.75)
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
canceled = Event() canceled = Event()
# reduce logging outputs to error # reduce logging outputs to error
@ -110,7 +141,11 @@ transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model # Initialize and load model
generate = Generate(model) generate = Generate(
model,
precision=precision,
embedding_path=embedding_path,
)
generate.load_model() generate.load_model()
@ -204,7 +239,7 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = real_esrgan_upscale( image = esrgan.process(
image=image, image=image,
upsampler_scale=esrgan_parameters["upscale"][0], upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters["upscale"][1], strength=esrgan_parameters["upscale"][1],
@ -275,11 +310,8 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = run_gfpgan( image = gfpgan.process(
image=image, image=image, strength=gfpgan_parameters["gfpgan_strength"], seed=seed
strength=gfpgan_parameters["gfpgan_strength"],
seed=seed,
upsampler_scale=1,
) )
progress["currentStatus"] = "Saving image" progress["currentStatus"] = "Saving image"
@ -537,7 +569,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
canceled.clear() canceled.clear()
step_index = 1 step_index = 1
prior_variations = (
generation_parameters["with_variations"]
if "with_variations" in generation_parameters
else []
)
""" """
If a result image is used as an init image, and then deleted, we will want to be If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it. able to use it as an init image in the future. Need to copy it.
@ -611,13 +647,14 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
def image_done(image, seed): def image_done(image, seed, first_seed):
nonlocal generation_parameters nonlocal generation_parameters
nonlocal esrgan_parameters nonlocal esrgan_parameters
nonlocal gfpgan_parameters nonlocal gfpgan_parameters
nonlocal progress nonlocal progress
step_index = 1 step_index = 1
nonlocal prior_variations
progress["currentStatus"] = "Generation complete" progress["currentStatus"] = "Generation complete"
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
@ -626,18 +663,27 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters = generation_parameters all_parameters = generation_parameters
postprocessing = False postprocessing = False
if (
"variation_amount" in all_parameters
and all_parameters["variation_amount"] > 0
):
first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = prior_variations + this_variation
if esrgan_parameters: if esrgan_parameters:
progress["currentStatus"] = "Upscaling" progress["currentStatus"] = "Upscaling"
progress["currentStatusHasSteps"] = False progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = real_esrgan_upscale( image = esrgan.process(
image=image, image=image,
strength=esrgan_parameters["strength"],
upsampler_scale=esrgan_parameters["level"], upsampler_scale=esrgan_parameters["level"],
strength=esrgan_parameters["strength"],
seed=seed, seed=seed,
) )
postprocessing = True postprocessing = True
all_parameters["upscale"] = [ all_parameters["upscale"] = [
esrgan_parameters["level"], esrgan_parameters["level"],
@ -650,16 +696,13 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = run_gfpgan( image = gfpgan.process(
image=image, image=image, strength=gfpgan_parameters["strength"], seed=seed
strength=gfpgan_parameters["strength"],
seed=seed,
upsampler_scale=1,
) )
postprocessing = True postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"] all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
all_parameters["seed"] = seed all_parameters["seed"] = first_seed
progress["currentStatus"] = "Saving image" progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)

File diff suppressed because one or more lines are too long

View File

@ -3,8 +3,8 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title> <title>InvokeAI Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.727a397b.js"></script> <script type="module" crossorigin src="/assets/index.632c341a.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css"> <link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head> </head>
<body> <body>

View File

@ -92,7 +92,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'} colorScheme={'gray'}
flexGrow={1} flexGrow={1}
variant={'outline'} variant={'outline'}
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)} isDisabled={!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)}
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
/> />
@ -101,7 +101,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'} colorScheme={'gray'}
flexGrow={1} flexGrow={1}
variant={'outline'} variant={'outline'}
isDisabled={!image.metadata.image.seed} isDisabled={!image?.metadata?.image?.seed}
onClick={handleClickUseSeed} onClick={handleClickUseSeed}
/> />

View File

@ -53,8 +53,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleClickSetSeed = (e: SyntheticEvent) => { const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation(); e.stopPropagation();
// Non-null assertion: this button is not rendered unless this exists
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
dispatch(setSeed(image.metadata.image.seed)); dispatch(setSeed(image.metadata.image.seed));
}; };
@ -109,7 +107,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
/> />
</DeleteImageModal> </DeleteImageModal>
</Tooltip> </Tooltip>
{['txt2img', 'img2img'].includes(image.metadata.image.type) && ( {['txt2img', 'img2img'].includes(image?.metadata?.image?.type) && (
<Tooltip label="Use all parameters"> <Tooltip label="Use all parameters">
<IconButton <IconButton
aria-label="Use all parameters" aria-label="Use all parameters"
@ -121,7 +119,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
/> />
</Tooltip> </Tooltip>
)} )}
{image.metadata.image.seed && ( {image?.metadata?.image?.seed && (
<Tooltip label="Use seed"> <Tooltip label="Use seed">
<IconButton <IconButton
aria-label="Use seed" aria-label="Use seed"

View File

@ -95,7 +95,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100'); const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
const metadata = image.metadata.image; const metadata = image?.metadata?.image || {};
const { const {
type, type,
postprocessing, postprocessing,
@ -119,12 +119,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const metadataJSON = JSON.stringify(metadata, null, 2); const metadataJSON = JSON.stringify(metadata, null, 2);
return ( return (
<Flex <Flex gap={1} direction={'column'} overflowY={'scroll'} width={'100%'}>
gap={1}
direction={'column'}
overflowY={'scroll'}
width={'100%'}
>
<Flex gap={2}> <Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text> <Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal> <Link href={image.url} isExternal>
@ -132,7 +127,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" /> <ExternalLinkIcon mx="2px" />
</Link> </Link>
</Flex> </Flex>
{Object.keys(metadata).length ? ( {Object.keys(metadata).length > 0 ? (
<> <>
{type && <MetadataItem label="Type" value={type} />} {type && <MetadataItem label="Type" value={type} />}
{['esrgan', 'gfpgan'].includes(type) && ( {['esrgan', 'gfpgan'].includes(type) && (
@ -288,9 +283,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
)} )}
<Flex gap={2} direction={'column'}> <Flex gap={2} direction={'column'}>
<Flex gap={2}> <Flex gap={2}>
<Tooltip label={`Copy JSON`}> <Tooltip label={`Copy metadata JSON`}>
<IconButton <IconButton
aria-label="Copy JSON" aria-label="Copy metadata JSON"
icon={<FaCopy />} icon={<FaCopy />}
size={'xs'} size={'xs'}
variant={'ghost'} variant={'ghost'}
@ -298,7 +293,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => navigator.clipboard.writeText(metadataJSON)} onClick={() => navigator.clipboard.writeText(metadataJSON)}
/> />
</Tooltip> </Tooltip>
<Text fontWeight={'semibold'}>JSON:</Text> <Text fontWeight={'semibold'}>Metadata JSON:</Text>
</Flex> </Flex>
<Box <Box
// maxHeight={200} // maxHeight={200}

View File

@ -132,6 +132,7 @@ const SeedVariationOptions = () => {
step={0.01} step={0.01}
min={0} min={0}
max={1} max={1}
isDisabled={!shouldGenerateVariations}
onChange={handleChangevariationAmount} onChange={handleChangevariationAmount}
/> />
<FormControl <FormControl