mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
commit
88e3b6d310
49
backend/modules/create_cmd_parser.py
Normal file
49
backend/modules/create_cmd_parser.py
Normal file
@ -0,0 +1,49 @@
|
||||
import argparse
|
||||
import os
|
||||
from ldm.dream.args import PRECISION_CHOICES
|
||||
|
||||
|
||||
def create_cmd_parser():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI web UI")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
help="The host to serve on",
|
||||
default="localhost",
|
||||
)
|
||||
parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
|
||||
parser.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding_path",
|
||||
type=str,
|
||||
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
|
||||
)
|
||||
# TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
|
||||
# parser.add_argument(
|
||||
# "--output_dir",
|
||||
# default="outputs/",
|
||||
# type=str,
|
||||
# help="Directory for output images",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
dest="precision",
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar="PRECISION",
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default="auto",
|
||||
)
|
||||
|
||||
return parser
|
@ -8,6 +8,16 @@ import glob
|
||||
import shlex
|
||||
import math
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from argparse import ArgumentTypeError
|
||||
from modules.create_cmd_parser import create_cmd_parser
|
||||
|
||||
parser = create_cmd_parser()
|
||||
opt = parser.parse_args()
|
||||
|
||||
|
||||
from flask_socketio import SocketIO
|
||||
from flask import Flask, send_from_directory, url_for, jsonify
|
||||
@ -18,9 +28,9 @@ from threading import Event
|
||||
from uuid import uuid4
|
||||
from send2trash import send2trash
|
||||
|
||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.dream.restoration import Restoration
|
||||
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.dream.conditioning import split_weighted_subprompts
|
||||
@ -31,15 +41,19 @@ from modules.parameters import parameters_to_command
|
||||
"""
|
||||
USER CONFIG
|
||||
"""
|
||||
if opt.cors and "*" in opt.cors:
|
||||
raise ArgumentTypeError('"*" is not an allowed CORS origin')
|
||||
|
||||
|
||||
output_dir = "outputs/" # Base output directory for images
|
||||
# host = 'localhost' # Web & socket.io host
|
||||
host = "localhost" # Web & socket.io host
|
||||
port = 9090 # Web & socket.io port
|
||||
verbose = False # enables copious socket.io logging
|
||||
additional_allowed_origins = [
|
||||
"http://localhost:5173"
|
||||
] # additional CORS allowed origins
|
||||
host = opt.host # Web & socket.io host
|
||||
port = opt.port # Web & socket.io port
|
||||
verbose = opt.verbose # enables copious socket.io logging
|
||||
precision = opt.precision
|
||||
embedding_path = opt.embedding_path
|
||||
additional_allowed_origins = (
|
||||
opt.cors if opt.cors else []
|
||||
) # additional CORS allowed origins
|
||||
model = "stable-diffusion-1.4"
|
||||
|
||||
"""
|
||||
@ -47,6 +61,9 @@ END USER CONFIG
|
||||
"""
|
||||
|
||||
|
||||
print("* Initializing, be patient...\n")
|
||||
|
||||
|
||||
"""
|
||||
SERVER SETUP
|
||||
"""
|
||||
@ -103,6 +120,20 @@ class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
from ldm.dream.restoration.base import Restoration
|
||||
|
||||
restoration = Restoration()
|
||||
gfpgan, codeformer = restoration.load_face_restore_models()
|
||||
esrgan = restoration.load_esrgan()
|
||||
|
||||
# coreformer.process(self, image, strength, device, seed=None, fidelity=0.75)
|
||||
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
|
||||
canceled = Event()
|
||||
|
||||
# reduce logging outputs to error
|
||||
@ -110,7 +141,11 @@ transformers.logging.set_verbosity_error()
|
||||
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
||||
|
||||
# Initialize and load model
|
||||
generate = Generate(model)
|
||||
generate = Generate(
|
||||
model,
|
||||
precision=precision,
|
||||
embedding_path=embedding_path,
|
||||
)
|
||||
generate.load_model()
|
||||
|
||||
|
||||
@ -204,7 +239,7 @@ def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = real_esrgan_upscale(
|
||||
image = esrgan.process(
|
||||
image=image,
|
||||
upsampler_scale=esrgan_parameters["upscale"][0],
|
||||
strength=esrgan_parameters["upscale"][1],
|
||||
@ -275,11 +310,8 @@ def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = run_gfpgan(
|
||||
image=image,
|
||||
strength=gfpgan_parameters["gfpgan_strength"],
|
||||
seed=seed,
|
||||
upsampler_scale=1,
|
||||
image = gfpgan.process(
|
||||
image=image, strength=gfpgan_parameters["gfpgan_strength"], seed=seed
|
||||
)
|
||||
|
||||
progress["currentStatus"] = "Saving image"
|
||||
@ -537,7 +569,11 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
canceled.clear()
|
||||
|
||||
step_index = 1
|
||||
|
||||
prior_variations = (
|
||||
generation_parameters["with_variations"]
|
||||
if "with_variations" in generation_parameters
|
||||
else []
|
||||
)
|
||||
"""
|
||||
If a result image is used as an init image, and then deleted, we will want to be
|
||||
able to use it as an init image in the future. Need to copy it.
|
||||
@ -611,13 +647,14 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed):
|
||||
def image_done(image, seed, first_seed):
|
||||
nonlocal generation_parameters
|
||||
nonlocal esrgan_parameters
|
||||
nonlocal gfpgan_parameters
|
||||
nonlocal progress
|
||||
|
||||
step_index = 1
|
||||
nonlocal prior_variations
|
||||
|
||||
progress["currentStatus"] = "Generation complete"
|
||||
socketio.emit("progressUpdate", progress)
|
||||
@ -626,18 +663,27 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
all_parameters = generation_parameters
|
||||
postprocessing = False
|
||||
|
||||
if (
|
||||
"variation_amount" in all_parameters
|
||||
and all_parameters["variation_amount"] > 0
|
||||
):
|
||||
first_seed = first_seed or seed
|
||||
this_variation = [[seed, all_parameters["variation_amount"]]]
|
||||
all_parameters["with_variations"] = prior_variations + this_variation
|
||||
|
||||
if esrgan_parameters:
|
||||
progress["currentStatus"] = "Upscaling"
|
||||
progress["currentStatusHasSteps"] = False
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = real_esrgan_upscale(
|
||||
image = esrgan.process(
|
||||
image=image,
|
||||
strength=esrgan_parameters["strength"],
|
||||
upsampler_scale=esrgan_parameters["level"],
|
||||
strength=esrgan_parameters["strength"],
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
postprocessing = True
|
||||
all_parameters["upscale"] = [
|
||||
esrgan_parameters["level"],
|
||||
@ -650,16 +696,13 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = run_gfpgan(
|
||||
image=image,
|
||||
strength=gfpgan_parameters["strength"],
|
||||
seed=seed,
|
||||
upsampler_scale=1,
|
||||
image = gfpgan.process(
|
||||
image=image, strength=gfpgan_parameters["strength"], seed=seed
|
||||
)
|
||||
postprocessing = True
|
||||
all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
|
||||
|
||||
all_parameters["seed"] = seed
|
||||
all_parameters["seed"] = first_seed
|
||||
progress["currentStatus"] = "Saving image"
|
||||
socketio.emit("progressUpdate", progress)
|
||||
eventlet.sleep(0)
|
||||
|
File diff suppressed because one or more lines are too long
4
frontend/dist/index.html
vendored
4
frontend/dist/index.html
vendored
@ -3,8 +3,8 @@
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Stable Diffusion Dream Server</title>
|
||||
<script type="module" crossorigin src="/assets/index.727a397b.js"></script>
|
||||
<title>InvokeAI Stable Diffusion Dream Server</title>
|
||||
<script type="module" crossorigin src="/assets/index.632c341a.js"></script>
|
||||
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
|
||||
</head>
|
||||
<body>
|
||||
|
@ -92,7 +92,7 @@ const CurrentImageButtons = ({
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
|
||||
isDisabled={!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)}
|
||||
onClick={handleClickUseAllParameters}
|
||||
/>
|
||||
|
||||
@ -101,7 +101,7 @@ const CurrentImageButtons = ({
|
||||
colorScheme={'gray'}
|
||||
flexGrow={1}
|
||||
variant={'outline'}
|
||||
isDisabled={!image.metadata.image.seed}
|
||||
isDisabled={!image?.metadata?.image?.seed}
|
||||
onClick={handleClickUseSeed}
|
||||
/>
|
||||
|
||||
|
@ -53,8 +53,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
|
||||
const handleClickSetSeed = (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
// Non-null assertion: this button is not rendered unless this exists
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
dispatch(setSeed(image.metadata.image.seed));
|
||||
};
|
||||
|
||||
@ -109,7 +107,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
/>
|
||||
</DeleteImageModal>
|
||||
</Tooltip>
|
||||
{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
|
||||
{['txt2img', 'img2img'].includes(image?.metadata?.image?.type) && (
|
||||
<Tooltip label="Use all parameters">
|
||||
<IconButton
|
||||
aria-label="Use all parameters"
|
||||
@ -121,7 +119,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{image.metadata.image.seed && (
|
||||
{image?.metadata?.image?.seed && (
|
||||
<Tooltip label="Use seed">
|
||||
<IconButton
|
||||
aria-label="Use seed"
|
||||
|
@ -95,7 +95,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
|
||||
|
||||
const metadata = image.metadata.image;
|
||||
const metadata = image?.metadata?.image || {};
|
||||
const {
|
||||
type,
|
||||
postprocessing,
|
||||
@ -119,12 +119,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const metadataJSON = JSON.stringify(metadata, null, 2);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
gap={1}
|
||||
direction={'column'}
|
||||
overflowY={'scroll'}
|
||||
width={'100%'}
|
||||
>
|
||||
<Flex gap={1} direction={'column'} overflowY={'scroll'} width={'100%'}>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight={'semibold'}>File:</Text>
|
||||
<Link href={image.url} isExternal>
|
||||
@ -132,7 +127,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
<ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
</Flex>
|
||||
{Object.keys(metadata).length ? (
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Type" value={type} />}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
@ -288,9 +283,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
)}
|
||||
<Flex gap={2} direction={'column'}>
|
||||
<Flex gap={2}>
|
||||
<Tooltip label={`Copy JSON`}>
|
||||
<Tooltip label={`Copy metadata JSON`}>
|
||||
<IconButton
|
||||
aria-label="Copy JSON"
|
||||
aria-label="Copy metadata JSON"
|
||||
icon={<FaCopy />}
|
||||
size={'xs'}
|
||||
variant={'ghost'}
|
||||
@ -298,7 +293,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight={'semibold'}>JSON:</Text>
|
||||
<Text fontWeight={'semibold'}>Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
// maxHeight={200}
|
||||
|
@ -132,6 +132,7 @@ const SeedVariationOptions = () => {
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
isDisabled={!shouldGenerateVariations}
|
||||
onChange={handleChangevariationAmount}
|
||||
/>
|
||||
<FormControl
|
||||
|
@ -1 +1,4 @@
|
||||
from .base import Restoration
|
||||
'''
|
||||
Initialization file for the ldm.dream.restoration package
|
||||
'''
|
||||
from .base import Restoration
|
||||
|
@ -839,7 +839,6 @@ class Generate:
|
||||
return model
|
||||
|
||||
def _load_img(self, path, width, height, fit=False):
|
||||
print(f'DEBUG: path = {path}')
|
||||
assert os.path.exists(path), f'>> {path}: File not found'
|
||||
|
||||
# with Image.open(path) as img:
|
||||
|
@ -8,6 +8,7 @@ import shlex
|
||||
import copy
|
||||
import warnings
|
||||
import time
|
||||
sys.path.append('.') # corrects a weird problem on Macs
|
||||
import ldm.dream.readline
|
||||
from ldm.dream.args import Args, metadata_dumps, metadata_from_png
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
@ -36,7 +37,6 @@ def main():
|
||||
sys.exit(-1)
|
||||
|
||||
print('* Initializing, be patient...\n')
|
||||
sys.path.append('.')
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
@ -170,9 +170,10 @@ def main_loop(gen, opt, infile):
|
||||
|
||||
if opt.init_img:
|
||||
try:
|
||||
oldargs = metadata_from_png(opt.init_img)
|
||||
opt.prompt = oldargs.prompt
|
||||
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
if not opt.prompt:
|
||||
oldargs = metadata_from_png(opt.init_img)
|
||||
opt.prompt = oldargs.prompt
|
||||
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
except AttributeError:
|
||||
pass
|
||||
except KeyError:
|
||||
|
Loading…
Reference in New Issue
Block a user