invoke-ai/InvokeAI (mirror of https://github.com/invoke-ai/InvokeAI)

Commit c0aa92ea13 (parent 8c751d342d)
Adds Restoration module, variations fix, cli args
backend/modules/create_cmd_parser.py (new file, 49 lines)

@@ -0,0 +1,49 @@
+import argparse
+import os
+from ldm.dream.args import PRECISION_CHOICES
+
+
+def create_cmd_parser():
+    parser = argparse.ArgumentParser(description="InvokeAI web UI")
+    parser.add_argument(
+        "--host",
+        type=str,
+        help="The host to serve on",
+        default="localhost",
+    )
+    parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
+    parser.add_argument(
+        "--cors",
+        nargs="*",
+        type=str,
+        help="Additional allowed origins, comma-separated",
+    )
+    parser.add_argument(
+        "--embedding_path",
+        type=str,
+        help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
+    )
+    # TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
+    # parser.add_argument(
+    #     "--output_dir",
+    #     default="outputs/",
+    #     type=str,
+    #     help="Directory for output images",
+    # )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Enables verbose logging",
+    )
+    parser.add_argument(
+        "--precision",
+        dest="precision",
+        type=str,
+        choices=PRECISION_CHOICES,
+        metavar="PRECISION",
+        help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
+        default="auto",
+    )
+
+    return parser
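The new module exposes a single factory, create_cmd_parser(). A minimal sketch of exercising it on its own, assuming it is run from the backend directory so the ldm package resolves; the flag values below are illustrative and not part of the commit:

# Illustrative usage only; argument values are examples, not from the commit.
from modules.create_cmd_parser import create_cmd_parser

parser = create_cmd_parser()
opt = parser.parse_args(
    ["--host", "0.0.0.0", "--port", "9090", "--cors", "http://localhost:5173"]
)
print(opt.host, opt.port, opt.cors, opt.precision)
# -> 0.0.0.0 9090 ['http://localhost:5173'] auto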
@@ -8,6 +8,16 @@ import glob
 import shlex
 import math
 import shutil
+import sys
+
+sys.path.append(".")
+
+from argparse import ArgumentTypeError
+from modules.create_cmd_parser import create_cmd_parser
+
+parser = create_cmd_parser()
+opt = parser.parse_args()
+
 from flask_socketio import SocketIO
 from flask import Flask, send_from_directory, url_for, jsonify
@@ -18,9 +28,9 @@ from threading import Event
 from uuid import uuid4
 from send2trash import send2trash

-from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
-from ldm.gfpgan.gfpgan_tools import run_gfpgan
 from ldm.generate import Generate
+from ldm.dream.restoration import Restoration
 from ldm.dream.pngwriter import PngWriter, retrieve_metadata
 from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
 from ldm.dream.conditioning import split_weighted_subprompts
@@ -31,15 +41,19 @@ from modules.parameters import parameters_to_command
 """
 USER CONFIG
 """
+if opt.cors and "*" in opt.cors:
+    raise ArgumentTypeError('"*" is not an allowed CORS origin')
+
+
 output_dir = "outputs/"  # Base output directory for images
-# host = 'localhost' # Web & socket.io host
-host = "localhost"  # Web & socket.io host
-port = 9090  # Web & socket.io port
-verbose = False  # enables copious socket.io logging
-additional_allowed_origins = [
-    "http://localhost:5173"
-]  # additional CORS allowed origins
+host = opt.host  # Web & socket.io host
+port = opt.port  # Web & socket.io port
+verbose = opt.verbose  # enables copious socket.io logging
+precision = opt.precision
+embedding_path = opt.embedding_path
+additional_allowed_origins = (
+    opt.cors if opt.cors else []
+)  # additional CORS allowed origins
 model = "stable-diffusion-1.4"

 """
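With the hunk above, every user-config value now comes from the parsed CLI options, and a wildcard CORS origin is rejected before the server starts. A small self-contained sketch of that origin handling, using argparse.Namespace as a stand-in for the real parsed options; the sample values are made up:

from argparse import ArgumentTypeError, Namespace

def resolve_origins(opt):
    # Same guard and fallback as in the diff: "*" is refused,
    # and omitting --cors yields an empty list of extra origins.
    if opt.cors and "*" in opt.cors:
        raise ArgumentTypeError('"*" is not an allowed CORS origin')
    return opt.cors if opt.cors else []

print(resolve_origins(Namespace(cors=None)))                       # []
print(resolve_origins(Namespace(cors=["http://localhost:5173"])))  # ['http://localhost:5173']
# resolve_origins(Namespace(cors=["*"]))  would raise ArgumentTypeError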
@@ -47,6 +61,9 @@ END USER CONFIG
 """


+print("* Initializing, be patient...\n")
+
+
 """
 SERVER SETUP
 """
@@ -103,6 +120,20 @@ class CanceledException(Exception):
     pass


+try:
+    gfpgan, codeformer, esrgan = None, None, None
+    from ldm.dream.restoration.base import Restoration
+
+    restoration = Restoration()
+    gfpgan, codeformer = restoration.load_face_restore_models()
+    esrgan = restoration.load_esrgan()
+
+    # coreformer.process(self, image, strength, device, seed=None, fidelity=0.75)
+
+except (ModuleNotFoundError, ImportError):
+    print(traceback.format_exc(), file=sys.stderr)
+    print(">> You may need to install the ESRGAN and/or GFPGAN modules")
+
 canceled = Event()

 # reduce logging outputs to error
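The restoration helpers (GFPGAN, CodeFormer, ESRGAN) are now obtained through the new Restoration class, and the import is wrapped so a missing optional dependency leaves the handles as None instead of killing the server. A minimal sketch of that guarded-import pattern, with a hypothetical optional_restoration package standing in for the real ldm.dream.restoration module:

import sys
import traceback

gfpgan, codeformer, esrgan = None, None, None
try:
    # 'optional_restoration' is a hypothetical stand-in for the real package.
    from optional_restoration import Restoration

    restoration = Restoration()
    gfpgan, codeformer = restoration.load_face_restore_models()
    esrgan = restoration.load_esrgan()
except (ModuleNotFoundError, ImportError):
    print(traceback.format_exc(), file=sys.stderr)
    print(">> Restoration modules not installed; continuing without them")

# Call sites can then guard on the handles, e.g.:
# if esrgan is not None:
#     image = esrgan.process(image=image, upsampler_scale=2, strength=0.75, seed=seed)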
@@ -110,7 +141,11 @@ transformers.logging.set_verbosity_error()
 logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

 # Initialize and load model
-generate = Generate(model)
+generate = Generate(
+    model,
+    precision=precision,
+    embedding_path=embedding_path,
+)
 generate.load_model()

@@ -204,7 +239,7 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
     socketio.emit("progressUpdate", progress)
     eventlet.sleep(0)

-    image = real_esrgan_upscale(
+    image = esrgan.process(
         image=image,
         upsampler_scale=esrgan_parameters["upscale"][0],
         strength=esrgan_parameters["upscale"][1],
@@ -275,11 +310,8 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
     socketio.emit("progressUpdate", progress)
     eventlet.sleep(0)

-    image = run_gfpgan(
-        image=image,
-        strength=gfpgan_parameters["gfpgan_strength"],
-        seed=seed,
-        upsampler_scale=1,
-    )
+    image = gfpgan.process(
+        image=image, strength=gfpgan_parameters["gfpgan_strength"], seed=seed
+    )

     progress["currentStatus"] = "Saving image"
@@ -537,7 +569,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
     canceled.clear()

     step_index = 1
+    prior_variations = (
+        generation_parameters["with_variations"]
+        if "with_variations" in generation_parameters
+        else []
+    )
     """
     If a result image is used as an init image, and then deleted, we will want to be
     able to use it as an init image in the future. Need to copy it.
@@ -611,13 +647,14 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
         socketio.emit("progressUpdate", progress)
         eventlet.sleep(0)

-    def image_done(image, seed):
+    def image_done(image, seed, first_seed):
         nonlocal generation_parameters
         nonlocal esrgan_parameters
         nonlocal gfpgan_parameters
         nonlocal progress
         step_index = 1
+        nonlocal prior_variations

         progress["currentStatus"] = "Generation complete"
         socketio.emit("progressUpdate", progress)
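image_done now receives first_seed and declares nonlocal prior_variations so the nested callback can read the variation list built earlier in generate_images. A tiny self-contained sketch of that closure pattern, with generic names and made-up seeds:

def make_done_callback():
    prior_variations = [[1234, 0.2]]  # built in the enclosing scope (example data)

    def image_done(seed, first_seed=None):
        nonlocal prior_variations  # read and rebind the enclosing variable
        first_seed = first_seed or seed
        prior_variations = prior_variations + [[seed, 0.1]]
        return first_seed, prior_variations

    return image_done

done = make_done_callback()
print(done(5678))  # (5678, [[1234, 0.2], [5678, 0.1]])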
@@ -626,18 +663,27 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
         all_parameters = generation_parameters
         postprocessing = False

+        if (
+            "variation_amount" in all_parameters
+            and all_parameters["variation_amount"] > 0
+        ):
+            first_seed = first_seed or seed
+            this_variation = [[seed, all_parameters["variation_amount"]]]
+            all_parameters["with_variations"] = prior_variations + this_variation
+
         if esrgan_parameters:
             progress["currentStatus"] = "Upscaling"
             progress["currentStatusHasSteps"] = False
             socketio.emit("progressUpdate", progress)
             eventlet.sleep(0)

-            image = real_esrgan_upscale(
+            image = esrgan.process(
                 image=image,
-                strength=esrgan_parameters["strength"],
                 upsampler_scale=esrgan_parameters["level"],
+                strength=esrgan_parameters["strength"],
                 seed=seed,
             )

             postprocessing = True
             all_parameters["upscale"] = [
                 esrgan_parameters["level"],
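This is the variations fix named in the commit title: when variation_amount is positive, the current result's [seed, weight] pair is appended to any variations the request already carried, and the saved parameters later record first_seed instead of the per-image seed. A short worked example with made-up seeds and weights, following the structure in the diff:

# Made-up example values; with_variations is a list of [seed, weight] pairs.
prior_variations = [[1111, 0.3]]   # variations the request came in with
seed, first_seed = 2222, 1111      # per-image seed vs. seed of the first image

all_parameters = {"variation_amount": 0.25}
if "variation_amount" in all_parameters and all_parameters["variation_amount"] > 0:
    first_seed = first_seed or seed
    this_variation = [[seed, all_parameters["variation_amount"]]]
    all_parameters["with_variations"] = prior_variations + this_variation

all_parameters["seed"] = first_seed
print(all_parameters)
# {'variation_amount': 0.25, 'with_variations': [[1111, 0.3], [2222, 0.25]], 'seed': 1111}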
@@ -650,16 +696,13 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
             socketio.emit("progressUpdate", progress)
             eventlet.sleep(0)

-            image = run_gfpgan(
-                image=image,
-                strength=gfpgan_parameters["strength"],
-                seed=seed,
-                upsampler_scale=1,
-            )
+            image = gfpgan.process(
+                image=image, strength=gfpgan_parameters["strength"], seed=seed
+            )
             postprocessing = True
             all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]

-        all_parameters["seed"] = seed
+        all_parameters["seed"] = first_seed
         progress["currentStatus"] = "Saving image"
         socketio.emit("progressUpdate", progress)
         eventlet.sleep(0)
File diff suppressed because one or more lines are too long

frontend/dist/index.html (vendored, 4 lines changed)

@@ -3,8 +3,8 @@
 <head>
     <meta charset="UTF-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-    <title>Stable Diffusion Dream Server</title>
-    <script type="module" crossorigin src="/assets/index.727a397b.js"></script>
+    <title>InvokeAI Stable Diffusion Dream Server</title>
+    <script type="module" crossorigin src="/assets/index.632c341a.js"></script>
     <link rel="stylesheet" href="/assets/index.447eb2a9.css">
 </head>
 <body>
@@ -92,7 +92,7 @@ const CurrentImageButtons = ({
 colorScheme={'gray'}
 flexGrow={1}
 variant={'outline'}
-isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
+isDisabled={!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)}
 onClick={handleClickUseAllParameters}
 />
@@ -101,7 +101,7 @@ const CurrentImageButtons = ({
 colorScheme={'gray'}
 flexGrow={1}
 variant={'outline'}
-isDisabled={!image.metadata.image.seed}
+isDisabled={!image?.metadata?.image?.seed}
 onClick={handleClickUseSeed}
 />
@@ -53,8 +53,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {

 const handleClickSetSeed = (e: SyntheticEvent) => {
   e.stopPropagation();
-  // Non-null assertion: this button is not rendered unless this exists
-  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
   dispatch(setSeed(image.metadata.image.seed));
 };
@@ -109,7 +107,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
 />
 </DeleteImageModal>
 </Tooltip>
-{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
+{['txt2img', 'img2img'].includes(image?.metadata?.image?.type) && (
 <Tooltip label="Use all parameters">
 <IconButton
 aria-label="Use all parameters"
@@ -121,7 +119,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
 />
 </Tooltip>
 )}
-{image.metadata.image.seed && (
+{image?.metadata?.image?.seed && (
 <Tooltip label="Use seed">
 <IconButton
 aria-label="Use seed"
@@ -95,7 +95,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
 const dispatch = useAppDispatch();
 const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');

-const metadata = image.metadata.image;
+const metadata = image?.metadata?.image || {};
 const {
   type,
   postprocessing,
@@ -119,12 +119,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
 const metadataJSON = JSON.stringify(metadata, null, 2);

 return (
-  <Flex
-    gap={1}
-    direction={'column'}
-    overflowY={'scroll'}
-    width={'100%'}
-  >
+  <Flex gap={1} direction={'column'} overflowY={'scroll'} width={'100%'}>
     <Flex gap={2}>
       <Text fontWeight={'semibold'}>File:</Text>
       <Link href={image.url} isExternal>
@@ -132,7 +127,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
 <ExternalLinkIcon mx="2px" />
 </Link>
 </Flex>
-{Object.keys(metadata).length ? (
+{Object.keys(metadata).length > 0 ? (
 <>
 {type && <MetadataItem label="Type" value={type} />}
 {['esrgan', 'gfpgan'].includes(type) && (
@@ -288,9 +283,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
 )}
 <Flex gap={2} direction={'column'}>
 <Flex gap={2}>
-<Tooltip label={`Copy JSON`}>
+<Tooltip label={`Copy metadata JSON`}>
 <IconButton
-aria-label="Copy JSON"
+aria-label="Copy metadata JSON"
 icon={<FaCopy />}
 size={'xs'}
 variant={'ghost'}
@@ -298,7 +293,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
 onClick={() => navigator.clipboard.writeText(metadataJSON)}
 />
 </Tooltip>
-<Text fontWeight={'semibold'}>JSON:</Text>
+<Text fontWeight={'semibold'}>Metadata JSON:</Text>
 </Flex>
 <Box
 // maxHeight={200}
@@ -132,6 +132,7 @@ const SeedVariationOptions = () => {
 step={0.01}
 min={0}
 max={1}
+isDisabled={!shouldGenerateVariations}
 onChange={handleChangevariationAmount}
 />
 <FormControl