Remove image generation node dependencies on generate.py (#2902)

# Remove node dependencies on generate.py

This is a draft PR in which I am replacing `generate.py` with a cleaner,
more structured interface to the underlying image generation routines.
The basic code pattern to generate an image using the new API is this:

```
from invokeai.backend import ModelManager, Txt2Img, Img2Img

manager = ModelManager('/data/lstein/invokeai-main/configs/models.yaml')
model = manager.get_model('stable-diffusion-1.5')
txt2img = Txt2Img(model)
outputs = txt2img.generate(prompt='banana sushi', steps=12, scheduler='k_euler_a', iterations=5)

# generate() returns an iterator
for next_output in outputs:
    print(next_output.image, next_output.seed)

outputs = Img2Img(model).generate(prompt='strawberry` sushi', init_img='./banana_sushi.png')
output = next(outputs)
output.image.save('strawberries.png')
```

### model management

The `ModelManager` handles model selection and initialization. Its
`get_model()` method will return a `dict` with the following keys:
`model`, `model_name`,`hash`, `width`, and `height`, where `model` is
the actual StableDiffusionGeneratorPIpeline. If `get_model()` is called
without a model name, it will return whatever is defined as the default
in `models.yaml`, or the first entry if no default is designated.

### InvokeAIGenerator

The abstract base class `InvokeAIGenerator` is subclassed into into
`Txt2Img`, `Img2Img`, `Inpaint` and `Embiggen`. The constructor for
these classes takes the model dict returned by
`model_manager.get_model()` and optionally an
`InvokeAIGeneratorBasicParams` object, which encapsulates all the
parameters in common among `Txt2Img`, `Img2Img` etc. If you don't
provide the basic params, a reasonable set of defaults will be chosen.
Any of these parameters can be overridden at `generate()` time.

These classes are defined in `invokeai.backend.generator`, but they are
also exported by `invokeai.backend` as shown in the example below.

```
from invokeai.backend import InvokeAIGeneratorBasicParams, Img2Img
params = InvokeAIGeneratorBasicParams(
    perlin = 0.15
    steps = 30
   scheduler = 'k_lms'
)
img2img = Img2Img(model, params)
outputs = img2img.generate(scheduler='k_heun')
```

Note that we were able to override the basic params in the call to
`generate()`

The `generate()` method will returns an iterator over a series of
`InvokeAIGeneratorOutput` objects. These objects contain the PIL image,
the seed, the model name and hash, and attributes for all the parameters
used to generate the object (you can also get these as a dict). The
`iterations` argument controls how many objects will be returned,
defaulting to 1. Pass `None` to get an infinite iterator.

Given the proposed use of `compel` to generate a templated series of
prompts, I thought the API would benefit from a style that lets you loop
over the output results indefinitely. I did consider returning a single
`InvokeAIGeneratorOutput` object in the event that `iterations=1`, but I
think it's dangerous for a method to return different types of result
under different circumstances.

Changing the model is as easy as this:
```
model = manager.get_model('inkspot-2.0`)
txt2img = Txt2Img(model)
```

### Node and legacy support

With respect to `Nodes`, I have written `model_manager_initializer` and
`restoration_services` modules that return `model_manager` and
`restoration` services respectively. The latter is used by the face
reconstruction and upscaling nodes. There is no longer any reference to
`Generate` in the `app` tree.

I have confirmed that `txt2img` and `img2img` work in the nodes client.
I have not tested `embiggen` or `inpaint` yet. pytests are passing, with
some warnings that I don't think are related to what I did.

The legacy WebUI and CLI are still working off `Generate` (which has not
yet been removed from the source tree) and fully functional.

I've finished all the tasks on my TODO list:

- [x] Update the pytests, which are failing due to dangling references
to `generate`
- [x] Rewrite the `reconstruct.py` and `upscale.py` nodes to call
directly into the postprocessing modules rather than going through
`Generate`
- [x] Update the pytests, which are failing due to dangling references
to `generate`
This commit is contained in:
Lincoln Stein 2023-03-11 21:48:23 -05:00 committed by GitHub
commit 1aaad9336f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1970 additions and 467 deletions

View File

@ -4,7 +4,8 @@ import os
from argparse import Namespace
from ...backend import Globals
from ..services.generate_initializer import get_generate
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -37,18 +38,16 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(args, config, event_handler_id: int):
Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet()
Globals.disable_xformers = not args.xformers
Globals.ckpt_convert = args.ckpt_convert
def initialize(config, event_handler_id: int):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
generate = get_generate(args, config)
events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(
@ -61,7 +60,7 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
generate=generate,
model_manager=get_model_manager(config),
events=events,
images=images,
queue=MemoryInvocationQueue(),
@ -69,6 +68,7 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
)
ApiDependencies.invoker = Invoker(services)

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from inspect import signature
@ -53,11 +52,11 @@ config = {}
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
args = Args()
config = args.parse_args()
config = Args()
config.parse_args()
ApiDependencies.initialize(
args=args, config=config, event_handler_id=event_handler_id
config=config, event_handler_id=event_handler_id
)

View File

@ -17,7 +17,8 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.generate_initializer import get_generate
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@ -126,14 +127,9 @@ def invoke_all(context: CliContext):
def invoke_cli():
args = Args()
config = args.parse_args()
generate = get_generate(args, config)
# NOTE: load model on first use, uncomment to load at startup
# TODO: Make this a config option?
# generate.load_model()
config = Args()
config.parse_args()
model_manager = get_model_manager(config)
events = EventServiceBase()
@ -145,7 +141,7 @@ def invoke_cli():
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
generate=generate,
model_manager=model_manager,
events=events,
images=DiskImageStorage(output_folder),
queue=MemoryInvocationQueue(),
@ -153,6 +149,7 @@ def invoke_cli():
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
)
invoker = Invoker(services)

View File

@ -12,12 +12,12 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
SAMPLER_NAME_VALUES = Literal[
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
tuple(InvokeAIGenerator.schedulers())
]
# Text to image
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
@ -57,19 +57,18 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
# (right now uses whatever current model is set in model manager)
model= context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=step_callback,
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
@ -78,7 +77,7 @@ class TextToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
@ -115,23 +114,20 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
model = context.services.model_manager.get_model()
generator_output = next(
Img2Img(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
result_image = results[0][0]
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
@ -145,7 +141,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -180,23 +175,20 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
manager = context.services.model_manager.get_model()
generator_output = next(
Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
result_image = results[0][0]
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?

View File

@ -8,7 +8,6 @@ from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
#fmt: off
@ -23,7 +22,7 @@ class RestoreFaceInvocation(BaseInvocation):
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.generate.upscale_and_reconstruct(
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength

View File

@ -26,7 +26,7 @@ class UpscaleInvocation(BaseInvocation):
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.generate.upscale_and_reconstruct(
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength

View File

@ -1,255 +0,0 @@
import os
import sys
import traceback
from argparse import Namespace
import invokeai.version
from invokeai.backend import Generate, ModelManager
from ...backend import Globals
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate:
if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
args, FileNotFoundError(f"The file {config_file} could not be found.")
)
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan, codeformer, esrgan = load_face_restoration(args)
# normalize the config directory relative to root
if not os.path.isabs(args.conf):
args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))
if args.embeddings:
if not os.path.isabs(args.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, args.embedding_path)
)
else:
embedding_path = args.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
# load the infile as a list of lines
if args.infile:
try:
if os.path.isfile(args.infile):
infile = open(args.infile, "r", encoding="utf-8")
elif args.infile == "-": # stdin
infile = sys.stdin
else:
raise FileNotFoundError(f"{args.infile} not found.")
except (FileNotFoundError, IOError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
# creating a Generate object:
try:
gen = Generate(
conf=args.conf,
model=args.model,
sampler_name=args.sampler_name,
embedding_path=embedding_path,
full_precision=args.full_precision,
precision=args.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=args.free_gpu_mem,
safety_checker=args.safety_checker,
max_loaded_models=args.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
if args.seamless:
print(">> changed to seamless tiling mode")
# preload the model
try:
gen.load_model()
except KeyError:
pass
except Exception as e:
report_model_error(args, e)
# try to autoconvert new models
# autoimport new .ckpt files
if path := args.autoconvert:
gen.model_manager.autoconvert_weights(
conf_path=args.conf,
weights_directory=path,
)
return gen
def load_face_restoration(opt):
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from invokeai.backend.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
opt.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
return gfpgan, codeformer, esrgan
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
return
print("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_args = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config)
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from invokeai.frontend.install import invokeai_configure
invokeai_configure()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)
# Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate:
# TODO: Remove the need for globals
from invokeai.backend.globals import Globals
# alert - setting globals here
Globals.root = os.path.expanduser(
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
)
Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan, codeformer, esrgan = None, None, None
try:
if config.restore or config.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if config.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
config.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
if config.esrgan:
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = None
# TODO: lazy-initialize this by wrapping it
try:
generate = Generate(
conf=config.conf,
model=config.model,
sampler_name=config.sampler_name,
embedding_path=embedding_path,
full_precision=config.full_precision,
precision=config.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=config.free_gpu_mem,
safety_checker=config.safety_checker,
max_loaded_models=config.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError):
# emergency_model_reconfigure() # TODO?
sys.exit(-1)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
generate.free_gpu_mem = config.free_gpu_mem
return generate

View File

@ -1,36 +1,39 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import Generate
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
class InvocationServices:
"""Services that can be used by invocations"""
generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
def __init__(
self,
generate: Generate,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
self,
model_manager: ModelManager,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
):
self.generate = generate
self.model_manager = model_manager
self.events = events
self.images = images
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration

View File

@ -0,0 +1,120 @@
import os
import sys
import torch
from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
import invokeai.version
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found.")
)
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
# creating the model manager
try:
device = torch.device(choose_torch_device())
precision = 'float16' if config.precision=='float16' \
else 'float32' if config.precision=='float32' \
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.conf),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
# autoimport new .ckpt files
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
)
return model_manager
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
return
print("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_config = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config.to_dict())
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from invokeai.frontend.install import invokeai_configure
invokeai_configure()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)

View File

@ -0,0 +1,109 @@
import sys
import traceback
import torch
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
if args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
if args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
print(
">> GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
print(
">> CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
print(">> Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@ -2,6 +2,15 @@
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint
)
from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

View File

@ -25,18 +25,19 @@ from accelerate.utils import set_seed
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
from .args import metadata_from_png
from .generator import infill_methods
from .globals import Globals, global_cache_dir
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .prompting import get_uc_and_c_and_ec
from .prompting.conditioning import log_tokenization
from .stable_diffusion import HuggingFaceConceptsLibrary
from .util import choose_precision, choose_torch_device
def fix_func(orig):
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@ -222,6 +223,7 @@ class Generate:
self.precision,
max_loaded_models=max_loaded_models,
sequential_offload=self.free_gpu_mem,
embedding_path=Path(self.embedding_path),
)
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
@ -244,31 +246,8 @@ class Generate:
# load safety checker if requested
if safety_checker:
try:
print(">> Initializing NSFW checker")
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_checker.to(self.device)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
print(">> Initializing NSFW checker")
self.safety_checker = SafetyChecker(self.device)
else:
print(">> NSFW checker is disabled")
@ -523,15 +502,6 @@ class Generate:
generator.set_variation(self.seed, variation_amount, with_variations)
generator.use_mps_noise = use_mps_noise
checker = (
{
"checker": self.safety_checker,
"extractor": self.safety_feature_extractor,
}
if self.safety_checker
else None
)
results = generator.generate(
prompt,
iterations=iterations,
@ -558,7 +528,7 @@ class Generate:
embiggen_strength=embiggen_strength,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius,
safety_checker=checker,
safety_checker=self.safety_checker,
seam_size=seam_size,
seam_blur=seam_blur,
seam_strength=seam_strength,
@ -940,18 +910,6 @@ class Generate:
self.generators = {}
set_seed(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
print(f">> Loading embeddings from {self.embedding_path}")
for root, _, files in os.walk(self.embedding_path):
for name in files:
ti_path = os.path.join(root, name)
self.model.textual_inversion_manager.load_textual_inversion(
ti_path, defer_injecting_tokens=True
)
print(
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
)
self.model_name = model_name
self._set_scheduler() # requires self.model_name to be set first
return self.model
@ -998,7 +956,7 @@ class Generate:
):
results = []
for r in image_list:
image, seed = r
image, seed, _ = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:

View File

@ -1,5 +1,13 @@
"""
Initialization file for the invokeai.generator package
"""
from .base import Generator
from .base import (
InvokeAIGenerator,
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint,
Generator,
)
from .inpaint import infill_methods

View File

@ -4,11 +4,15 @@ including img2img, txt2img, and inpaint
"""
from __future__ import annotations
import itertools
import dataclasses
import diffusers
import os
import random
import traceback
from abc import ABCMeta
from argparse import Namespace
from contextlib import nullcontext
from pathlib import Path
import cv2
import numpy as np
@ -17,13 +21,258 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Iterator, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.assets.web as web_assets
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
downsampling = 8
CAUTION_IMG = "caution.png"
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None
@dataclass
class InvokeAIGeneratorOutput:
'''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
seed: int
model_hash: str
attention_maps_images: List[Image]
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_info=model_info
self.params=params
def generate(self,
prompt: str='',
callback: callable=None,
step_callback: callable=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
'''
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
object. Use like this:
outputs = txt2img.generate(prompt='banana sushi', iterations=5)
for result in outputs:
print(result.image, result.seed)
In the typical case of wanting to get just a single image, iterations
defaults to 1 and do:
output = next(txt2img.generate(prompt='banana sushi')
Pass None to get an infinite iterator.
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
for o in outputs:
print(o.image, o.seed)
'''
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
model_info = self.model_info
model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@classmethod
def _generator_class(cls)->Type[Generator]:
'''
In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived
class, which nicely parallels the generator class names.
'''
return Generator
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
@classmethod
def _generator_class(cls):
from .txt2img import Txt2Img
return Txt2Img
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
init_image: Image | torch.FloatTensor,
strength: float=0.75,
**keyword_args
)->List[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image,
strength=strength,
**keyword_args
)
@classmethod
def _generator_class(cls):
from .img2img import Img2Img
return Img2Img
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
def generate(self,
mask_image: Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args
)->List[InvokeAIGeneratorOutput]:
return super().generate(
mask_image=mask_image,
seam_size=seam_size,
seam_blur=seam_blur,
seam_strength=seam_strength,
seam_steps=seam_steps,
tile_size=tile_size,
inpaint_replace=inpaint_replace,
infill_method=infill_method,
inpaint_width=inpaint_width,
inpaint_height=inpaint_height,
inpaint_fill=inpaint_fill,
**keyword_args
)
@classmethod
def _generator_class(cls):
from .inpaint import Inpaint
return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
def generate(
self,
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
@ -44,7 +293,6 @@ class Generator:
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
self.caution_img = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
@ -64,10 +312,10 @@ class Generator:
def generate(
self,
prompt,
init_image,
width,
height,
sampler,
init_image=None,
iterations=1,
seed=None,
image_callback=None,
@ -76,7 +324,7 @@ class Generator:
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: dict = None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False,
**kwargs,
):
@ -130,9 +378,9 @@ class Generator:
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
image = self.safety_checker.check(image)
results.append([image, seed])
results.append([image, seed, attention_maps_images])
if image_callback is not None:
attention_maps_image = (
@ -292,16 +540,6 @@ class Generator:
seed = random.randrange(0, np.iinfo(np.uint32).max)
return (seed, initial_noise)
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height):
"""
Returns a tensor filled with random numbers, either form a normal distribution
(txt2img) or from the latent image (img2img, inpaint)
"""
raise NotImplementedError(
"get_noise() must be implemented in a descendent class"
)
def get_perlin_noise(self, width, height):
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
# limit noise to only the diffusion image channels, not the mask channels
@ -361,53 +599,6 @@ class Generator:
return v2
def safety_check(self, image: Image.Image):
"""
If the CompViz safety checker flags an NSFW image, we
blur it out.
"""
import diffusers
checker = self.safety_checker["checker"]
extractor = self.safety_checker["extractor"]
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(
images=x_image, clip_input=features.pixel_values
)
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = self.get_caution_img()
if caution:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):

View File

@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -51,23 +50,29 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
}
class ModelManager(object):
'''
Model manager handles loading, caching, importing, deleting, converting, and editing models.
'''
def __init__(
self,
config: OmegaConf,
device_type: torch.device = CPU_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
self,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path=None,
):
"""
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
min_avail_mem argument specifies how much unused system
(CPU) memory to preserve. The cache of models in RAM will
grow until this value is approached. Default is 2G.
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error()
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.precision = precision
self.device = torch.device(device_type)
@ -76,6 +81,7 @@ class ModelManager(object):
self.stack = [] # this is an LRU FIFO
self.current_model = None
self.sequential_offload = sequential_offload
self.embedding_path = embedding_path
def valid_model(self, model_name: str) -> bool:
"""
@ -84,12 +90,15 @@ class ModelManager(object):
"""
return model_name in self.config
def get_model(self, model_name: str):
def get_model(self, model_name: str=None)->dict:
"""
Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
"""
if not model_name:
return self.current_model if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
@ -112,6 +121,7 @@ class ModelManager(object):
else: # we're about to load a new model, so potentially offload the least recently used one
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {
"model_name": model_name,
"model": requested_model,
"width": width,
"height": height,
@ -121,6 +131,7 @@ class ModelManager(object):
self.current_model = model_name
self._push_newest_model(model_name)
return {
"model_name": model_name,
"model": requested_model,
"width": width,
"height": height,
@ -425,6 +436,7 @@ class ModelManager(object):
height = width
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@ -1061,6 +1073,19 @@ class ModelManager(object):
self.stack.remove(model_name)
self.stack.append(model_name)
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
if self.embedding_path is not None:
print(f">> Loading embeddings from {self.embedding_path}")
for root, _, files in os.walk(self.embedding_path):
for name in files:
ti_path = os.path.join(root, name)
model.textual_inversion_manager.load_textual_inversion(
ti_path, defer_injecting_tokens=True
)
print(
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
)
def _has_cuda(self) -> bool:
return self.device.type == "cuda"

View File

@ -0,0 +1,82 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
from .globals import global_cache_dir
from .util import CPU_DEVICE
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
self.safety_checker.to(self.device)
features = self.safety_feature_extractor([image], return_tensors="pt")
features.to(self.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
if caution := self.caution_img:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 KiB

179
static/dream_web/index.css Normal file
View File

@ -0,0 +1,179 @@
:root {
--fields-dark:#DCDCDC;
--fields-light:#F5F5F5;
}
* {
font-family: 'Arial';
font-size: 100%;
}
body {
font-size: 1em;
}
textarea {
font-size: 0.95em;
}
header, form, #progress-section {
margin-left: auto;
margin-right: auto;
max-width: 1024px;
text-align: center;
}
fieldset {
border: none;
line-height: 2.2em;
}
fieldset > legend {
width: auto;
margin-left: 0;
margin-right: auto;
font-weight:bold;
}
select, input {
margin-right: 10px;
padding: 2px;
}
input:disabled {
cursor:auto;
}
input[type=submit] {
cursor: pointer;
background-color: #666;
color: white;
}
input[type=checkbox] {
cursor: pointer;
margin-right: 0px;
width: 20px;
height: 20px;
vertical-align: middle;
}
input#seed {
margin-right: 0px;
}
div {
padding: 10px 10px 10px 10px;
}
header {
margin-bottom: 16px;
}
header h1 {
margin-bottom: 0;
font-size: 2em;
}
#search-box {
display: flex;
}
#scaling-inprocess-message {
font-weight: bold;
font-style: italic;
display: none;
}
#prompt {
flex-grow: 1;
padding: 5px 10px 5px 10px;
border: 1px solid #999;
outline: none;
}
#submit {
padding: 5px 10px 5px 10px;
border: 1px solid #999;
}
#reset-all, #remove-image {
margin-top: 12px;
font-size: 0.8em;
background-color: pink;
border: 1px solid #999;
border-radius: 4px;
}
#results {
text-align: center;
margin: auto;
padding-top: 10px;
}
#results figure {
display: inline-block;
margin: 10px;
}
#results figcaption {
font-size: 0.8em;
padding: 3px;
color: #888;
cursor: pointer;
}
#results img {
border-radius: 5px;
object-fit: contain;
background-color: var(--fields-dark);
}
#fieldset-config {
line-height:2em;
}
input[type="number"] {
width: 60px;
}
#seed {
width: 150px;
}
button#reset-seed {
font-size: 1.7em;
background: #efefef;
border: 1px solid #999;
border-radius: 4px;
line-height: 0.8;
margin: 0 10px 0 0;
padding: 0 5px 3px;
vertical-align: middle;
}
label {
white-space: nowrap;
}
#progress-section {
display: none;
}
#progress-image {
width: 30vh;
height: 30vh;
object-fit: contain;
background-color: var(--fields-dark);
}
#cancel-button {
cursor: pointer;
color: red;
}
#txt2img {
background-color: var(--fields-dark);
}
#variations {
background-color: var(--fields-light);
}
#initimg {
background-color: var(--fields-dark);
}
#img2img {
background-color: var(--fields-light);
}
#initimg > :not(legend) {
background-color: var(--fields-light);
margin: .5em;
}
#postprocess, #initimg {
display:flex;
flex-wrap:wrap;
padding: 0;
margin-top: 1em;
background-color: var(--fields-dark);
}
#postprocess > fieldset, #initimg > * {
flex-grow: 1;
}
#postprocess > fieldset {
background-color: var(--fields-dark);
}
#progress-section {
background-color: var(--fields-light);
}
#no-results-message:not(:only-child) {
display: none;
}

187
static/dream_web/index.html Normal file
View File

@ -0,0 +1,187 @@
<html lang="en">
<head>
<title>Stable Diffusion Dream Server</title>
<meta charset="utf-8">
<link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="config.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
crossorigin="anonymous"></script>
<link rel="stylesheet" href="index.css">
<script src="index.js"></script>
</head>
<body>
<header>
<h1>Stable Diffusion Dream Server</h1>
<div id="about">
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub
site</a>
</div>
</header>
<main>
<!--
<div id="dropper" style="background-color:red;width:200px;height:200px;">
</div>
-->
<form id="generate-form" method="post" action="api/jobs">
<fieldset id="txt2img">
<legend>
<input type="checkbox" name="enable_generate" id="enable_generate" checked>
<label for="enable_generate">Generate</label>
</legend>
<div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea>
</div>
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler_name">Sampler:</label>
<select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option>
<option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option>
<option value="k_dpm_2">KDPM_2</option>
<option value="k_dpm_2_a">KDPM_2A</option>
<option value="k_dpmpp_2">KDPMPP_2</option>
<option value="k_dpmpp_2_a">KDPMPP_2A</option>
<option value="k_euler">KEULER</option>
<option value="k_euler_a">KEULER_A</option>
<option value="k_heun">KHEUN</option>
</select>
<input type="checkbox" name="seamless" id="seamless">
<label for="seamless">Seamless circular tiling</label>
<br>
<label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512">
<option value="64">64</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select>
<label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512">
<option value="64">64</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
<option value="768">768</option>
<option value="832">832</option>
<option value="896">896</option>
<option value="960">960</option>
<option value="1024">1024</option>
</select>
<label title="Set to 0 for random seed" for="seed">Seed:</label>
<input value="0" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label>
<div>
<label title="If > 0, adds thresholding to restrict values for k-diffusion samplers (0 disables)" for="threshold">Threshold:</label>
<input value="0" type="number" id="threshold" name="threshold" step="0.1" min="0">
<label title="Perlin: optional 0-1 value adds a percentage of perlin noise to the initial noise" for="perlin">Perlin:</label>
<input value="0" type="number" id="perlin" name="perlin" step="0.01" min="0" max="1">
<button type="button" id="reset-all">Reset to Defaults</button>
</div>
<div id="variations">
<label
title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different."
for="variation_amount">Variation amount (0 to disable):</label>
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..."
for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<input value="" type="text" id="with_variations" name="with_variations">
</div>
</fieldset>
<fieldset id="initimg">
<legend>
<input type="checkbox" name="enable_init_image" id="enable_init_image" checked>
<label for="enable_init_image">Enable init image</label>
</legend>
<div>
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="remove-image">Remove Image</button>
</div>
<fieldset id="img2img">
<legend>
<input type="checkbox" name="enable_img2img" id="enable_img2img" checked>
<label for="enable_img2img">Enable Img2Img</label>
</legend>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
</fieldset>
</fieldset>
<div id="postprocess">
<fieldset id="gfpgan">
<legend>
<input type="checkbox" name="enable_gfpgan" id="enable_gfpgan">
<label for="enable_gfpgan">Enable gfpgan</label>
</legend>
<label title="Strength of the gfpgan (face fixing) algorithm." for="facetool_strength">GPFGAN Strength:</label>
<input value="0.8" min="0" max="1" type="number" id="facetool_strength" name="facetool_strength" step="0.05">
</fieldset>
<fieldset id="upscale">
<legend>
<input type="checkbox" name="enable_upscale" id="enable_upscale">
<label for="enable_upscale">Enable Upscaling</label>
</legend>
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level:</label>
<select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option>
<option value="2">2x</option>
<option value="4">4x</option>
</select>
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</fieldset>
</div>
<input type="submit" id="submit" value="Generate">
</form>
<br>
<section id="progress-section">
<div id="progress-container">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>
<br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i>
</div>
</div>
</section>
<div id="results">
</div>
</main>
</body>
</html>

396
static/dream_web/index.js Normal file
View File

@ -0,0 +1,396 @@
const socket = io();
var priorResultsLoadState = {
page: 0,
pages: 1,
per_page: 10,
total: 20,
offset: 0, // number of items generated since last load
loading: false,
initialized: false
};
function loadPriorResults() {
// Fix next page by offset
let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page;
priorResultsLoadState.page += offsetPages;
priorResultsLoadState.pages += offsetPages;
priorResultsLoadState.total += priorResultsLoadState.offset;
priorResultsLoadState.offset = 0;
if (priorResultsLoadState.loading) {
return;
}
if (priorResultsLoadState.page >= priorResultsLoadState.pages) {
return; // Nothing more to load
}
// Load
priorResultsLoadState.loading = true
let url = new URL('/api/images', document.baseURI);
url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page);
url.searchParams.append('per_page', priorResultsLoadState.per_page);
fetch(url.href, {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(data => {
priorResultsLoadState.page = data.page;
priorResultsLoadState.pages = data.pages;
priorResultsLoadState.per_page = data.per_page;
priorResultsLoadState.total = data.total;
data.items.forEach(function(dreamId, index) {
let src = 'api/images/' + dreamId;
fetch('/api/images/' + dreamId + '/metadata', {
method: 'GET',
headers: new Headers({'content-type': 'application/json'})
})
.then(response => response.json())
.then(metadata => {
let seed = metadata.seed || 0; // TODO: Parse old metadata
appendOutput(src, seed, metadata, true);
});
});
// Load until page is full
if (!priorResultsLoadState.initialized) {
if (document.body.scrollHeight <= window.innerHeight) {
loadPriorResults();
}
}
})
.finally(() => {
priorResultsLoadState.loading = false;
priorResultsLoadState.initialized = true;
});
}
function resetForm() {
var form = document.getElementById('generate-form');
form.querySelector('fieldset').removeAttribute('disabled');
}
function initProgress(totalSteps, showProgressImages) {
// TODO: Progress could theoretically come from multiple jobs at the same time (in the future)
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'initial';
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = showProgressImages ? 'initial': 'none';
}
function setProgress(step, totalSteps, src) {
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('value', step);
if (src) {
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = src;
}
}
function resetProgress(hide = true) {
if (hide) {
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'none';
}
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('value', 0);
}
function toBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
function ondragdream(event) {
let dream = event.target.dataset.dream;
event.dataTransfer.setData("dream", dream);
}
function seedClick(event) {
// Get element
var image = event.target.closest('figure').querySelector('img');
var dream = JSON.parse(decodeURIComponent(image.dataset.dream));
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
if (k == 'initimg') { continue; }
let formElem = form.querySelector(`*[name=${k}]`);
formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue;
}
document.querySelector("#seed").value = dream.seed;
document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job)
// NOTE: leaving this manual for the user for now - it was very confusing with this behavior
// document.querySelector("#with_variations").value = variations || '';
// if (document.querySelector("#variation_amount").value <= 0) {
// document.querySelector("#variation_amount").value = 0.2;
// }
saveFields(document.querySelector("#generate-form"));
}
function appendOutput(src, seed, config, toEnd=false) {
let outputNode = document.createElement("figure");
let altText = seed.toString() + " | " + config.prompt;
// img needs width and height for lazy loading to work
// TODO: store the full config in a data attribute on the image?
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}"
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256"
draggable="true"
ondragstart="ondragdream(event, this)"
data-dream="${encodeURIComponent(JSON.stringify(config))}"
data-dreamId="${encodeURIComponent(config.dreamId)}">
</a>
<figcaption onclick="seedClick(event, this)">${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
if (toEnd) {
document.querySelector("#results").append(outputNode);
} else {
document.querySelector("#results").prepend(outputNode);
}
document.querySelector("#no-results-message")?.remove();
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v);
}
}
}
function loadFields(form) {
for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k);
if (item != null) {
form.querySelector(`*[name=${k}]`).value = item;
}
}
}
function clearFields(form) {
localStorage.clear();
let prompt = form.prompt.value;
form.reset();
form.prompt.value = prompt;
}
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
async function generateSubmit(form) {
// Convert file data to base64
// TODO: Should probably uplaod files with formdata or something, and store them in the backend?
let formData = Object.fromEntries(new FormData(form));
if (!formData.enable_generate && !formData.enable_init_image) {
gen_label = document.querySelector("label[for=enable_generate]").innerHTML;
initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML;
alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`);
}
formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Evaluate all checkboxes
let checkboxes = form.querySelectorAll('input[type=checkbox]');
checkboxes.forEach(function (checkbox) {
if (checkbox.checked) {
formData[checkbox.name] = 'true';
}
});
let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let showProgressImages = formData.progress_images;
// Set enabling flags
// Initialize the progress bar
initProgress(totalSteps, showProgressImages);
// POST, use response to listen for events
fetch(form.action, {
method: form.method,
headers: new Headers({'content-type': 'application/json'}),
body: JSON.stringify(formData),
})
.then(response => response.json())
.then(data => {
var jobId = data.jobId;
socket.emit('join_room', { 'room': jobId });
});
form.querySelector('fieldset').setAttribute('disabled','');
}
function fieldSetEnableChecked(event) {
cb = event.target;
fields = cb.closest('fieldset');
fields.disabled = !cb.checked;
}
// Socket listeners
socket.on('job_started', (data) => {})
socket.on('dream_result', (data) => {
var jobId = data.jobId;
var dreamId = data.dreamId;
var dreamRequest = data.dreamRequest;
var src = 'api/images/' + dreamId;
priorResultsLoadState.offset += 1;
appendOutput(src, dreamRequest.seed, dreamRequest);
resetProgress(false);
})
socket.on('dream_progress', (data) => {
// TODO: it'd be nice if we could get a seed reported here, but the generator would need to be updated
var step = data.step;
var totalSteps = data.totalSteps;
var jobId = data.jobId;
var dreamId = data.dreamId;
var progressType = data.progressType
if (progressType === 'GENERATION') {
var src = data.hasProgressImage ?
'api/intermediates/' + dreamId + '/' + step
: null;
setProgress(step, totalSteps, src);
} else if (progressType === 'UPSCALING_STARTED') {
// step and totalSteps are used for upscale count on this message
document.getElementById("processing_cnt").textContent = step;
document.getElementById("processing_total").textContent = totalSteps;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (progressType == 'UPSCALING_DONE') {
document.getElementById("scaling-inprocess-message").style.display = "none";
}
})
socket.on('job_canceled', (data) => {
resetForm();
resetProgress();
})
socket.on('job_done', (data) => {
jobId = data.jobId
socket.emit('leave_room', { 'room': jobId });
resetForm();
resetProgress();
})
window.onload = async () => {
document.querySelector("#prompt").addEventListener("keydown", (e) => {
if (e.key === "Enter" && !e.shiftKey) {
const form = e.target.form;
generateSubmit(form);
}
});
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
});
document.querySelector("#generate-form").addEventListener('change', (e) => {
saveFields(e.target.form);
});
document.querySelector("#reset-seed").addEventListener('click', (e) => {
document.querySelector("#seed").value = 0;
saveFields(e.target.form);
});
document.querySelector("#reset-all").addEventListener('click', (e) => {
clearFields(e.target.form);
});
document.querySelector("#remove-image").addEventListener('click', (e) => {
initimg.value=null;
});
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/api/cancel').catch(e => {
console.error(e);
});
});
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/api/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}
window.addEventListener("scroll", () => {
if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) {
loadPriorResults();
}
});
// Enable/disable forms by checkboxes
document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) {
cb.addEventListener('change', fieldSetEnableChecked);
fieldSetEnableChecked({ target: cb})
});
// Load some of the previous results
loadPriorResults();
// Image drop/upload WIP
/*
let drop = document.getElementById('dropper');
function ondrop(event) {
let dreamData = event.dataTransfer.getData('dream');
if (dreamData) {
var dream = JSON.parse(decodeURIComponent(dreamData));
alert(dream.dreamId);
}
};
function ondragenter(event) {
event.preventDefault();
};
function ondragover(event) {
event.preventDefault();
};
function ondragleave(event) {
}
drop.addEventListener('drop', ondrop);
drop.addEventListener('dragenter', ondragenter);
drop.addEventListener('dragover', ondragover);
drop.addEventListener('dragleave', ondragleave);
*/
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 KiB

152
static/legacy_web/index.css Normal file
View File

@ -0,0 +1,152 @@
* {
font-family: 'Arial';
font-size: 100%;
}
body {
font-size: 1em;
}
textarea {
font-size: 0.95em;
}
header, form, #progress-section {
margin-left: auto;
margin-right: auto;
max-width: 1024px;
text-align: center;
}
fieldset {
border: none;
line-height: 2.2em;
}
select, input {
margin-right: 10px;
padding: 2px;
}
input[type=submit] {
background-color: #666;
color: white;
}
input[type=checkbox] {
margin-right: 0px;
width: 20px;
height: 20px;
vertical-align: middle;
}
input#seed {
margin-right: 0px;
}
div {
padding: 10px 10px 10px 10px;
}
header {
margin-bottom: 16px;
}
header h1 {
margin-bottom: 0;
font-size: 2em;
}
#search-box {
display: flex;
}
#scaling-inprocess-message {
font-weight: bold;
font-style: italic;
display: none;
}
#prompt {
flex-grow: 1;
padding: 5px 10px 5px 10px;
border: 1px solid #999;
outline: none;
}
#submit {
padding: 5px 10px 5px 10px;
border: 1px solid #999;
}
#reset-all, #remove-image {
margin-top: 12px;
font-size: 0.8em;
background-color: pink;
border: 1px solid #999;
border-radius: 4px;
}
#results {
text-align: center;
margin: auto;
padding-top: 10px;
}
#results figure {
display: inline-block;
margin: 10px;
}
#results figcaption {
font-size: 0.8em;
padding: 3px;
color: #888;
cursor: pointer;
}
#results img {
border-radius: 5px;
object-fit: cover;
}
#fieldset-config {
line-height:2em;
background-color: #F0F0F0;
}
input[type="number"] {
width: 60px;
}
#seed {
width: 150px;
}
button#reset-seed {
font-size: 1.7em;
background: #efefef;
border: 1px solid #999;
border-radius: 4px;
line-height: 0.8;
margin: 0 10px 0 0;
padding: 0 5px 3px;
vertical-align: middle;
}
label {
white-space: nowrap;
}
#progress-section {
display: none;
}
#progress-image {
width: 30vh;
height: 30vh;
}
#cancel-button {
cursor: pointer;
color: red;
}
#basic-parameters {
background-color: #EEEEEE;
}
#txt2img {
background-color: #DCDCDC;
}
#variations {
background-color: #EEEEEE;
}
#img2img {
background-color: #DCDCDC;
}
#gfpgan {
background-color: #EEEEEE;
}
#progress-section {
background-color: #F5F5F5;
}
.section-header {
text-align: left;
font-weight: bold;
padding: 0 0 0 0;
}
#no-results-message:not(:only-child) {
display: none;
}

View File

@ -0,0 +1,137 @@
<html lang="en">
<head>
<title>Stable Diffusion Dream Server</title>
<meta charset="utf-8">
<link rel="icon" type="image/x-icon" href="static/legacy_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="static/legacy_web/index.css">
<script src="config.js"></script>
<script src="static/legacy_web/index.js"></script>
</head>
<body>
<header>
<h1>Stable Diffusion Dream Server</h1>
<div id="about">
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a>
</div>
</header>
<main>
<form id="generate-form" method="post" action="#">
<fieldset id="txt2img">
<div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea>
<input type="submit" id="submit" value="Generate">
</div>
</fieldset>
<fieldset id="fieldset-config">
<div class="section-header">Basic options</div>
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler_name">Sampler:</label>
<select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option>
<option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option>
<option value="k_dpm_2">KDPM_2</option>
<option value="k_dpm_2_a">KDPM_2A</option>
<option value="k_dpmpp_2">KDPMPP_2</option>
<option value="k_dpmpp_2_a">KDPMPP_2A</option>
<option value="k_euler">KEULER</option>
<option value="k_euler_a">KEULER_A</option>
<option value="k_heun">KHEUN</option>
</select>
<input type="checkbox" name="seamless" id="seamless">
<label for="seamless">Seamless circular tiling</label>
<br>
<label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label>
<div>
<label title="If > 0, adds thresholding to restrict values for k-diffusion samplers (0 disables)" for="threshold">Threshold:</label>
<input value="0" type="number" id="threshold" name="threshold" step="0.1" min="0">
<label title="Perlin: optional 0-1 value adds a percentage of perlin noise to the initial noise" for="perlin">Perlin:</label>
<input value="0" type="number" id="perlin" name="perlin" step="0.01" min="0" max="1">
<button type="button" id="reset-all">Reset to Defaults</button>
</div>
<span id="variations">
<label title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different." for="variation_amount">Variation amount (0 to disable):</label>
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..." for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<input value="" type="text" id="with_variations" name="with_variations">
</span>
</fieldset>
<fieldset id="img2img">
<div class="section-header">Image-to-image options</div>
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="remove-image">Remove Image</button>
<br>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height</label>
</fieldset>
<fieldset id="gfpgan">
<div class="section-header">Post-processing options</div>
<label title="Strength of the gfpgan (face fixing) algorithm." for="facetool_strength">GPFGAN Strength (0 to disable):</label>
<input value="0.0" min="0" max="1" type="number" id="facetool_strength" name="facetool_strength" step="0.1">
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label>
<select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option>
<option value="2">2x</option>
<option value="4">4x</option>
</select>
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</fieldset>
</form>
<br>
<section id="progress-section">
<div id="progress-container">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>
<br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
</div>
</span>
</section>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div>
</main>
</body>
</html>

213
static/legacy_web/index.js Normal file
View File

@ -0,0 +1,213 @@
function toBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
function appendOutput(src, seed, config) {
let outputNode = document.createElement("figure");
let variations = config.with_variations;
if (config.variation_amount > 0) {
variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount;
}
let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed;
let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
// img needs width and height for lazy loading to work
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}"
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256">
</a>
<figcaption>${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
let figcaption = outputNode.querySelector('figcaption');
// Reload image config
figcaption.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
if (k == 'initimg') { continue; }
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = baseseed;
document.querySelector("#with_variations").value = variations || '';
if (document.querySelector("#variation_amount").value <= 0) {
document.querySelector("#variation_amount").value = 0.2;
}
saveFields(document.querySelector("#generate-form"));
});
document.querySelector("#results").prepend(outputNode);
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v);
}
}
}
function loadFields(form) {
for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k);
if (item != null) {
form.querySelector(`*[name=${k}]`).value = item;
}
}
}
function clearFields(form) {
localStorage.clear();
let prompt = form.prompt.value;
form.reset();
form.prompt.value = prompt;
}
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
async function generateSubmit(form) {
const prompt = document.querySelector("#prompt").value;
// Convert file data to base64
let formData = Object.fromEntries(new FormData(form));
formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'initial';
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none';
// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (response) => {
const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) {
progressSectionEle.style.display = 'none';
break;
}
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event === 'result') {
noOutputs = false;
appendOutput(data.url, data.seed, data.config);
progressEle.setAttribute('value', 0);
progressEle.setAttribute('max', totalSteps);
} else if (data.event === 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event === 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event === 'step') {
progressEle.setAttribute('value', data.step);
if (data.url) {
progressImageEle.src = data.url;
}
} else if (data.event === 'canceled') {
// avoid alerting as if this were an error case
noOutputs = false;
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (noOutputs) {
alert("Error occurred while generating.");
}
});
// Disable form while generating
form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
}
async function fetchRunLog() {
try {
let response = await fetch('/run_log.json')
const data = await response.json();
for(let item of data.run_log) {
appendOutput(item.url, item.seed, item);
}
} catch (e) {
console.error(e);
}
}
window.onload = async () => {
document.querySelector("#prompt").addEventListener("keydown", (e) => {
if (e.key === "Enter" && !e.shiftKey) {
const form = e.target.form;
generateSubmit(form);
}
});
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
});
document.querySelector("#generate-form").addEventListener('change', (e) => {
saveFields(e.target.form);
});
document.querySelector("#reset-seed").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1;
saveFields(e.target.form);
});
document.querySelector("#reset-all").addEventListener('click', (e) => {
clearFields(e.target.form);
});
document.querySelector("#remove-image").addEventListener('click', (e) => {
initimg.value=null;
});
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => {
console.error(e);
});
});
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}
await fetchRunLog()
};

View File

@ -21,12 +21,13 @@ def simple_graph():
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
generate = None,
model_manager = None,
events = None,
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
processor = DefaultInvocationProcessor(),
restoration = None,
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:

View File

@ -21,12 +21,13 @@ def simple_graph():
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
generate = None, # type: ignore
model_manager = None, # type: ignore
events = TestEventService(),
images = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
processor = DefaultInvocationProcessor(),
restoration = None,
)
@pytest.fixture()