Adds Restoration module, variations fix, cli args

This commit is contained in:
psychedelicious 2022-09-22 00:32:26 +10:00
parent a80119f826
commit 7cf7ba42fb
8 changed files with 132 additions and 46 deletions

View File

@ -0,0 +1,49 @@
import argparse
import os
from ldm.dream.args import PRECISION_CHOICES
def create_cmd_parser():
parser = argparse.ArgumentParser(description="InvokeAI web UI")
parser.add_argument(
"--host",
type=str,
help="The host to serve on",
default="localhost",
)
parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
parser.add_argument(
"--cors",
nargs="*",
type=str,
help="Additional allowed origins, comma-separated",
)
parser.add_argument(
"--embedding_path",
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
)
# TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
# parser.add_argument(
# "--output_dir",
# default="outputs/",
# type=str,
# help="Directory for output images",
# )
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Enables verbose logging",
)
parser.add_argument(
"--precision",
dest="precision",
type=str,
choices=PRECISION_CHOICES,
metavar="PRECISION",
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default="auto",
)
return parser

View File

@ -8,6 +8,16 @@ import glob
import shlex
import math
import shutil
import sys
sys.path.append(".")
from argparse import ArgumentTypeError
from modules.create_cmd_parser import create_cmd_parser
parser = create_cmd_parser()
opt = parser.parse_args()
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
@ -18,9 +28,9 @@ from threading import Event
from uuid import uuid4
from send2trash import send2trash
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate
from ldm.dream.restoration import Restoration
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts
@ -31,15 +41,19 @@ from modules.parameters import parameters_to_command
"""
USER CONFIG
"""
if opt.cors and "*" in opt.cors:
raise ArgumentTypeError('"*" is not an allowed CORS origin')
output_dir = "outputs/" # Base output directory for images
# host = 'localhost' # Web & socket.io host
host = "localhost" # Web & socket.io host
port = 9090 # Web & socket.io port
verbose = False # enables copious socket.io logging
additional_allowed_origins = [
"http://localhost:5173"
] # additional CORS allowed origins
host = opt.host # Web & socket.io host
port = opt.port # Web & socket.io port
verbose = opt.verbose # enables copious socket.io logging
precision = opt.precision
embedding_path = opt.embedding_path
additional_allowed_origins = (
opt.cors if opt.cors else []
) # additional CORS allowed origins
model = "stable-diffusion-1.4"
"""
@ -47,6 +61,9 @@ END USER CONFIG
"""
print("* Initializing, be patient...\n")
"""
SERVER SETUP
"""
@ -103,6 +120,20 @@ class CanceledException(Exception):
pass
try:
gfpgan, codeformer, esrgan = None, None, None
from ldm.dream.restoration.base import Restoration
restoration = Restoration()
gfpgan, codeformer = restoration.load_face_restore_models()
esrgan = restoration.load_esrgan()
# coreformer.process(self, image, strength, device, seed=None, fidelity=0.75)
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
canceled = Event()
# reduce logging outputs to error
@ -110,7 +141,11 @@ transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model
generate = Generate(model)
generate = Generate(
model,
precision=precision,
embedding_path=embedding_path,
)
generate.load_model()
@ -204,7 +239,7 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image = esrgan.process(
image=image,
upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters["upscale"][1],
@ -275,11 +310,8 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters["gfpgan_strength"],
seed=seed,
upsampler_scale=1,
image = gfpgan.process(
image=image, strength=gfpgan_parameters["gfpgan_strength"], seed=seed
)
progress["currentStatus"] = "Saving image"
@ -537,7 +569,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
canceled.clear()
step_index = 1
prior_variations = (
generation_parameters["with_variations"]
if "with_variations" in generation_parameters
else []
)
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
@ -611,13 +647,14 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_done(image, seed):
def image_done(image, seed, first_seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
nonlocal prior_variations
progress["currentStatus"] = "Generation complete"
socketio.emit("progressUpdate", progress)
@ -626,18 +663,27 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
all_parameters = generation_parameters
postprocessing = False
if (
"variation_amount" in all_parameters
and all_parameters["variation_amount"] > 0
):
first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = prior_variations + this_variation
if esrgan_parameters:
progress["currentStatus"] = "Upscaling"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image = esrgan.process(
image=image,
strength=esrgan_parameters["strength"],
upsampler_scale=esrgan_parameters["level"],
strength=esrgan_parameters["strength"],
seed=seed,
)
postprocessing = True
all_parameters["upscale"] = [
esrgan_parameters["level"],
@ -650,16 +696,13 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters["strength"],
seed=seed,
upsampler_scale=1,
image = gfpgan.process(
image=image, strength=gfpgan_parameters["strength"], seed=seed
)
postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
all_parameters["seed"] = seed
all_parameters["seed"] = first_seed
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)

File diff suppressed because one or more lines are too long

View File

@ -3,8 +3,8 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.727a397b.js"></script>
<title>InvokeAI Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.632c341a.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head>
<body>

View File

@ -92,7 +92,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
isDisabled={!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)}
onClick={handleClickUseAllParameters}
/>
@ -101,7 +101,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!image.metadata.image.seed}
isDisabled={!image?.metadata?.image?.seed}
onClick={handleClickUseSeed}
/>

View File

@ -53,8 +53,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
// Non-null assertion: this button is not rendered unless this exists
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
dispatch(setSeed(image.metadata.image.seed));
};
@ -109,7 +107,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
/>
</DeleteImageModal>
</Tooltip>
{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
{['txt2img', 'img2img'].includes(image?.metadata?.image?.type) && (
<Tooltip label="Use all parameters">
<IconButton
aria-label="Use all parameters"
@ -121,7 +119,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
/>
</Tooltip>
)}
{image.metadata.image.seed && (
{image?.metadata?.image?.seed && (
<Tooltip label="Use seed">
<IconButton
aria-label="Use seed"

View File

@ -95,7 +95,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
const metadata = image.metadata.image;
const metadata = image?.metadata?.image || {};
const {
type,
postprocessing,
@ -119,12 +119,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const metadataJSON = JSON.stringify(metadata, null, 2);
return (
<Flex
gap={1}
direction={'column'}
overflowY={'scroll'}
width={'100%'}
>
<Flex gap={1} direction={'column'} overflowY={'scroll'} width={'100%'}>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
@ -132,7 +127,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(metadata).length ? (
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Type" value={type} />}
{['esrgan', 'gfpgan'].includes(type) && (
@ -288,9 +283,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
)}
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<Tooltip label={`Copy JSON`}>
<Tooltip label={`Copy metadata JSON`}>
<IconButton
aria-label="Copy JSON"
aria-label="Copy metadata JSON"
icon={<FaCopy />}
size={'xs'}
variant={'ghost'}
@ -298,7 +293,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight={'semibold'}>JSON:</Text>
<Text fontWeight={'semibold'}>Metadata JSON:</Text>
</Flex>
<Box
// maxHeight={200}

View File

@ -132,6 +132,7 @@ const SeedVariationOptions = () => {
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={handleChangevariationAmount}
/>
<FormControl