mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove image generation node dependencies on generate.py (#2902)
# Remove node dependencies on generate.py This is a draft PR in which I am replacing `generate.py` with a cleaner, more structured interface to the underlying image generation routines. The basic code pattern to generate an image using the new API is this: ``` from invokeai.backend import ModelManager, Txt2Img, Img2Img manager = ModelManager('/data/lstein/invokeai-main/configs/models.yaml') model = manager.get_model('stable-diffusion-1.5') txt2img = Txt2Img(model) outputs = txt2img.generate(prompt='banana sushi', steps=12, scheduler='k_euler_a', iterations=5) # generate() returns an iterator for next_output in outputs: print(next_output.image, next_output.seed) outputs = Img2Img(model).generate(prompt='strawberry` sushi', init_img='./banana_sushi.png') output = next(outputs) output.image.save('strawberries.png') ``` ### model management The `ModelManager` handles model selection and initialization. Its `get_model()` method will return a `dict` with the following keys: `model`, `model_name`,`hash`, `width`, and `height`, where `model` is the actual StableDiffusionGeneratorPIpeline. If `get_model()` is called without a model name, it will return whatever is defined as the default in `models.yaml`, or the first entry if no default is designated. ### InvokeAIGenerator The abstract base class `InvokeAIGenerator` is subclassed into into `Txt2Img`, `Img2Img`, `Inpaint` and `Embiggen`. The constructor for these classes takes the model dict returned by `model_manager.get_model()` and optionally an `InvokeAIGeneratorBasicParams` object, which encapsulates all the parameters in common among `Txt2Img`, `Img2Img` etc. If you don't provide the basic params, a reasonable set of defaults will be chosen. Any of these parameters can be overridden at `generate()` time. These classes are defined in `invokeai.backend.generator`, but they are also exported by `invokeai.backend` as shown in the example below. ``` from invokeai.backend import InvokeAIGeneratorBasicParams, Img2Img params = InvokeAIGeneratorBasicParams( perlin = 0.15 steps = 30 scheduler = 'k_lms' ) img2img = Img2Img(model, params) outputs = img2img.generate(scheduler='k_heun') ``` Note that we were able to override the basic params in the call to `generate()` The `generate()` method will returns an iterator over a series of `InvokeAIGeneratorOutput` objects. These objects contain the PIL image, the seed, the model name and hash, and attributes for all the parameters used to generate the object (you can also get these as a dict). The `iterations` argument controls how many objects will be returned, defaulting to 1. Pass `None` to get an infinite iterator. Given the proposed use of `compel` to generate a templated series of prompts, I thought the API would benefit from a style that lets you loop over the output results indefinitely. I did consider returning a single `InvokeAIGeneratorOutput` object in the event that `iterations=1`, but I think it's dangerous for a method to return different types of result under different circumstances. Changing the model is as easy as this: ``` model = manager.get_model('inkspot-2.0`) txt2img = Txt2Img(model) ``` ### Node and legacy support With respect to `Nodes`, I have written `model_manager_initializer` and `restoration_services` modules that return `model_manager` and `restoration` services respectively. The latter is used by the face reconstruction and upscaling nodes. There is no longer any reference to `Generate` in the `app` tree. I have confirmed that `txt2img` and `img2img` work in the nodes client. I have not tested `embiggen` or `inpaint` yet. pytests are passing, with some warnings that I don't think are related to what I did. The legacy WebUI and CLI are still working off `Generate` (which has not yet been removed from the source tree) and fully functional. I've finished all the tasks on my TODO list: - [x] Update the pytests, which are failing due to dangling references to `generate` - [x] Rewrite the `reconstruct.py` and `upscale.py` nodes to call directly into the postprocessing modules rather than going through `Generate` - [x] Update the pytests, which are failing due to dangling references to `generate`
This commit is contained in:
commit
1aaad9336f
@ -4,7 +4,8 @@ import os
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.generate_initializer import get_generate
|
from ..services.model_manager_initializer import get_model_manager
|
||||||
|
from ..services.restoration_services import RestorationServices
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_storage import DiskImageStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -37,18 +38,16 @@ class ApiDependencies:
|
|||||||
invoker: Invoker = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(args, config, event_handler_id: int):
|
def initialize(config, event_handler_id: int):
|
||||||
Globals.try_patchmatch = args.patchmatch
|
Globals.try_patchmatch = config.patchmatch
|
||||||
Globals.always_use_cpu = args.always_use_cpu
|
Globals.always_use_cpu = config.always_use_cpu
|
||||||
Globals.internet_available = args.internet_available and check_internet()
|
Globals.internet_available = config.internet_available and check_internet()
|
||||||
Globals.disable_xformers = not args.xformers
|
Globals.disable_xformers = not config.xformers
|
||||||
Globals.ckpt_convert = args.ckpt_convert
|
Globals.ckpt_convert = config.ckpt_convert
|
||||||
|
|
||||||
# TODO: Use a logger
|
# TODO: Use a logger
|
||||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
generate = get_generate(args, config)
|
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
output_folder = os.path.abspath(
|
output_folder = os.path.abspath(
|
||||||
@ -61,7 +60,7 @@ class ApiDependencies:
|
|||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
generate=generate,
|
model_manager=get_model_manager(config),
|
||||||
events=events,
|
events=events,
|
||||||
images=images,
|
images=images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
@ -69,6 +68,7 @@ class ApiDependencies:
|
|||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
@ -53,11 +52,11 @@ config = {}
|
|||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
args = Args()
|
config = Args()
|
||||||
config = args.parse_args()
|
config.parse_args()
|
||||||
|
|
||||||
ApiDependencies.initialize(
|
ApiDependencies.initialize(
|
||||||
args=args, config=config, event_handler_id=event_handler_id
|
config=config, event_handler_id=event_handler_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,8 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
|
|||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.generate_initializer import get_generate
|
from .services.model_manager_initializer import get_model_manager
|
||||||
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import EdgeConnection, GraphExecutionState
|
from .services.graph import EdgeConnection, GraphExecutionState
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -126,14 +127,9 @@ def invoke_all(context: CliContext):
|
|||||||
|
|
||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
args = Args()
|
config = Args()
|
||||||
config = args.parse_args()
|
config.parse_args()
|
||||||
|
model_manager = get_model_manager(config)
|
||||||
generate = get_generate(args, config)
|
|
||||||
|
|
||||||
# NOTE: load model on first use, uncomment to load at startup
|
|
||||||
# TODO: Make this a config option?
|
|
||||||
# generate.load_model()
|
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
@ -145,7 +141,7 @@ def invoke_cli():
|
|||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
generate=generate,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
images=DiskImageStorage(output_folder),
|
images=DiskImageStorage(output_folder),
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
@ -153,6 +149,7 @@ def invoke_cli():
|
|||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
|
@ -12,12 +12,12 @@ from ..services.image_storage import ImageType
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class TextToImageInvocation(BaseInvocation):
|
class TextToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image using text2img."""
|
"""Generates an image using text2img."""
|
||||||
@ -57,19 +57,18 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
# (right now uses whatever current model is set in model manager)
|
||||||
self.model = context.services.generate.model_name
|
model= context.services.model_manager.get_model()
|
||||||
|
outputs = Txt2Img(model).generate(
|
||||||
# Set the model (if already cached, this does nothing)
|
|
||||||
context.services.generate.set_model(self.model)
|
|
||||||
|
|
||||||
results = context.services.generate.prompt2image(
|
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=step_callback,
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
|
# each time it is called. We only need the first one.
|
||||||
|
generate_output = next(outputs)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
@ -78,7 +77,7 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, results[0][0])
|
context.services.images.save(image_type, image_name, generate_output.image)
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
@ -115,23 +114,20 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
model = context.services.model_manager.get_model()
|
||||||
self.model = context.services.generate.model_name
|
generator_output = next(
|
||||||
|
Img2Img(model).generate(
|
||||||
# Set the model (if already cached, this does nothing)
|
prompt=self.prompt,
|
||||||
context.services.generate.set_model(self.model)
|
init_img=image,
|
||||||
|
init_mask=mask,
|
||||||
results = context.services.generate.prompt2image(
|
step_callback=step_callback,
|
||||||
prompt=self.prompt,
|
**self.dict(
|
||||||
init_img=image,
|
exclude={"prompt", "image", "mask"}
|
||||||
init_mask=mask,
|
), # Shorthand for passing all of the parameters above manually
|
||||||
step_callback=step_callback,
|
)
|
||||||
**self.dict(
|
|
||||||
exclude={"prompt", "image", "mask"}
|
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_image = results[0][0]
|
result_image = generator_output.image
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
@ -145,7 +141,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@ -180,23 +175,20 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
manager = context.services.model_manager.get_model()
|
||||||
self.model = context.services.generate.model_name
|
generator_output = next(
|
||||||
|
Inpaint(model).generate(
|
||||||
# Set the model (if already cached, this does nothing)
|
prompt=self.prompt,
|
||||||
context.services.generate.set_model(self.model)
|
init_img=image,
|
||||||
|
init_mask=mask,
|
||||||
results = context.services.generate.prompt2image(
|
step_callback=step_callback,
|
||||||
prompt=self.prompt,
|
**self.dict(
|
||||||
init_img=image,
|
exclude={"prompt", "image", "mask"}
|
||||||
init_mask=mask,
|
), # Shorthand for passing all of the parameters above manually
|
||||||
step_callback=step_callback,
|
)
|
||||||
**self.dict(
|
|
||||||
exclude={"prompt", "image", "mask"}
|
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_image = results[0][0]
|
result_image = generator_output.image
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
|
@ -8,7 +8,6 @@ from ..services.invocation_services import InvocationServices
|
|||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
@ -23,7 +22,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.generate.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=None,
|
upscale=None,
|
||||||
strength=self.strength, # GFPGAN strength
|
strength=self.strength, # GFPGAN strength
|
||||||
|
@ -26,7 +26,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.generate.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=(self.level, self.strength),
|
upscale=(self.level, self.strength),
|
||||||
strength=0.0, # GFPGAN strength
|
strength=0.0, # GFPGAN strength
|
||||||
|
@ -1,255 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from argparse import Namespace
|
|
||||||
|
|
||||||
import invokeai.version
|
|
||||||
from invokeai.backend import Generate, ModelManager
|
|
||||||
|
|
||||||
from ...backend import Globals
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
|
||||||
def get_generate(args, config) -> Generate:
|
|
||||||
if not args.conf:
|
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
|
||||||
if not os.path.exists(config_file):
|
|
||||||
report_model_error(
|
|
||||||
args, FileNotFoundError(f"The file {config_file} could not be found.")
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
|
||||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
|
||||||
# when the frozen CLIP tokenizer is imported
|
|
||||||
import transformers # type: ignore
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
import diffusers
|
|
||||||
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
# Loading Face Restoration and ESRGAN Modules
|
|
||||||
gfpgan, codeformer, esrgan = load_face_restoration(args)
|
|
||||||
|
|
||||||
# normalize the config directory relative to root
|
|
||||||
if not os.path.isabs(args.conf):
|
|
||||||
args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))
|
|
||||||
|
|
||||||
if args.embeddings:
|
|
||||||
if not os.path.isabs(args.embedding_path):
|
|
||||||
embedding_path = os.path.normpath(
|
|
||||||
os.path.join(Globals.root, args.embedding_path)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding_path = args.embedding_path
|
|
||||||
else:
|
|
||||||
embedding_path = None
|
|
||||||
|
|
||||||
# migrate legacy models
|
|
||||||
ModelManager.migrate_models()
|
|
||||||
|
|
||||||
# load the infile as a list of lines
|
|
||||||
if args.infile:
|
|
||||||
try:
|
|
||||||
if os.path.isfile(args.infile):
|
|
||||||
infile = open(args.infile, "r", encoding="utf-8")
|
|
||||||
elif args.infile == "-": # stdin
|
|
||||||
infile = sys.stdin
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"{args.infile} not found.")
|
|
||||||
except (FileNotFoundError, IOError) as e:
|
|
||||||
print(f"{e}. Aborting.")
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
# creating a Generate object:
|
|
||||||
try:
|
|
||||||
gen = Generate(
|
|
||||||
conf=args.conf,
|
|
||||||
model=args.model,
|
|
||||||
sampler_name=args.sampler_name,
|
|
||||||
embedding_path=embedding_path,
|
|
||||||
full_precision=args.full_precision,
|
|
||||||
precision=args.precision,
|
|
||||||
gfpgan=gfpgan,
|
|
||||||
codeformer=codeformer,
|
|
||||||
esrgan=esrgan,
|
|
||||||
free_gpu_mem=args.free_gpu_mem,
|
|
||||||
safety_checker=args.safety_checker,
|
|
||||||
max_loaded_models=args.max_loaded_models,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
|
||||||
report_model_error(opt, e)
|
|
||||||
except (IOError, KeyError) as e:
|
|
||||||
print(f"{e}. Aborting.")
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
if args.seamless:
|
|
||||||
print(">> changed to seamless tiling mode")
|
|
||||||
|
|
||||||
# preload the model
|
|
||||||
try:
|
|
||||||
gen.load_model()
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
report_model_error(args, e)
|
|
||||||
|
|
||||||
# try to autoconvert new models
|
|
||||||
# autoimport new .ckpt files
|
|
||||||
if path := args.autoconvert:
|
|
||||||
gen.model_manager.autoconvert_weights(
|
|
||||||
conf_path=args.conf,
|
|
||||||
weights_directory=path,
|
|
||||||
)
|
|
||||||
|
|
||||||
return gen
|
|
||||||
|
|
||||||
|
|
||||||
def load_face_restoration(opt):
|
|
||||||
try:
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
if opt.restore or opt.esrgan:
|
|
||||||
from invokeai.backend.restoration import Restoration
|
|
||||||
|
|
||||||
restoration = Restoration()
|
|
||||||
if opt.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
opt.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(">> Face restoration disabled")
|
|
||||||
if opt.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
print(">> Upscaling disabled")
|
|
||||||
else:
|
|
||||||
print(">> Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
return gfpgan, codeformer, esrgan
|
|
||||||
|
|
||||||
|
|
||||||
def report_model_error(opt: Namespace, e: Exception):
|
|
||||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
|
||||||
print(
|
|
||||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
|
||||||
)
|
|
||||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
|
||||||
if yes_to_all:
|
|
||||||
print(
|
|
||||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = input(
|
|
||||||
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
|
|
||||||
)
|
|
||||||
if response.startswith(("n", "N")):
|
|
||||||
return
|
|
||||||
|
|
||||||
print("invokeai-configure is launching....\n")
|
|
||||||
|
|
||||||
# Match arguments that were set on the CLI
|
|
||||||
# only the arguments accepted by the configuration script are parsed
|
|
||||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
|
||||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
|
||||||
previous_args = sys.argv
|
|
||||||
sys.argv = ["invokeai-configure"]
|
|
||||||
sys.argv.extend(root_dir)
|
|
||||||
sys.argv.extend(config)
|
|
||||||
if yes_to_all is not None:
|
|
||||||
for arg in yes_to_all.split():
|
|
||||||
sys.argv.append(arg)
|
|
||||||
|
|
||||||
from invokeai.frontend.install import invokeai_configure
|
|
||||||
|
|
||||||
invokeai_configure()
|
|
||||||
# TODO: Figure out how to restart
|
|
||||||
# print('** InvokeAI will now restart')
|
|
||||||
# sys.argv = previous_args
|
|
||||||
# main() # would rather do a os.exec(), but doesn't exist?
|
|
||||||
# sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
# Temporary initializer for Generate until we migrate off of it
|
|
||||||
def old_get_generate(args, config) -> Generate:
|
|
||||||
# TODO: Remove the need for globals
|
|
||||||
from invokeai.backend.globals import Globals
|
|
||||||
|
|
||||||
# alert - setting globals here
|
|
||||||
Globals.root = os.path.expanduser(
|
|
||||||
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
|
|
||||||
)
|
|
||||||
Globals.try_patchmatch = args.patchmatch
|
|
||||||
|
|
||||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
|
||||||
# when the frozen CLIP tokenizer is imported
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
# Loading Face Restoration and ESRGAN Modules
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
try:
|
|
||||||
if config.restore or config.esrgan:
|
|
||||||
from ldm.invoke.restoration import Restoration
|
|
||||||
|
|
||||||
restoration = Restoration()
|
|
||||||
if config.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
config.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(">> Face restoration disabled")
|
|
||||||
if config.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
print(">> Upscaling disabled")
|
|
||||||
else:
|
|
||||||
print(">> Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
|
|
||||||
# normalize the config directory relative to root
|
|
||||||
if not os.path.isabs(config.conf):
|
|
||||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
|
||||||
|
|
||||||
if config.embeddings:
|
|
||||||
if not os.path.isabs(config.embedding_path):
|
|
||||||
embedding_path = os.path.normpath(
|
|
||||||
os.path.join(Globals.root, config.embedding_path)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding_path = None
|
|
||||||
|
|
||||||
# TODO: lazy-initialize this by wrapping it
|
|
||||||
try:
|
|
||||||
generate = Generate(
|
|
||||||
conf=config.conf,
|
|
||||||
model=config.model,
|
|
||||||
sampler_name=config.sampler_name,
|
|
||||||
embedding_path=embedding_path,
|
|
||||||
full_precision=config.full_precision,
|
|
||||||
precision=config.precision,
|
|
||||||
gfpgan=gfpgan,
|
|
||||||
codeformer=codeformer,
|
|
||||||
esrgan=esrgan,
|
|
||||||
free_gpu_mem=config.free_gpu_mem,
|
|
||||||
safety_checker=config.safety_checker,
|
|
||||||
max_loaded_models=config.max_loaded_models,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, TypeError, AssertionError):
|
|
||||||
# emergency_model_reconfigure() # TODO?
|
|
||||||
sys.exit(-1)
|
|
||||||
except (IOError, KeyError) as e:
|
|
||||||
print(f"{e}. Aborting.")
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
generate.free_gpu_mem = config.free_gpu_mem
|
|
||||||
|
|
||||||
return generate
|
|
@ -1,36 +1,39 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
from invokeai.backend import Generate
|
from invokeai.backend import ModelManager
|
||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
from .image_storage import ImageStorageBase
|
from .image_storage import ImageStorageBase
|
||||||
|
from .restoration_services import RestorationServices
|
||||||
from .invocation_queue import InvocationQueueABC
|
from .invocation_queue import InvocationQueueABC
|
||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
generate: Generate # TODO: wrap Generate, or split it up from model?
|
|
||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
|
model_manager: ModelManager
|
||||||
|
restoration: RestorationServices
|
||||||
|
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
generate: Generate,
|
model_manager: ModelManager,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
|
restoration: RestorationServices,
|
||||||
):
|
):
|
||||||
self.generate = generate
|
self.model_manager = model_manager
|
||||||
self.events = events
|
self.events = events
|
||||||
self.images = images
|
self.images = images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
self.restoration = restoration
|
||||||
|
120
invokeai/app/services/model_manager_initializer.py
Normal file
120
invokeai/app/services/model_manager_initializer.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
from argparse import Namespace
|
||||||
|
from invokeai.backend import Args
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import invokeai.version
|
||||||
|
from ...backend import ModelManager
|
||||||
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
|
from ...backend import Globals
|
||||||
|
|
||||||
|
# TODO: Replace with an abstract class base ModelManagerBase
|
||||||
|
def get_model_manager(config: Args) -> ModelManager:
|
||||||
|
if not config.conf:
|
||||||
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
|
if not os.path.exists(config_file):
|
||||||
|
report_model_error(
|
||||||
|
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||||
|
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
|
# these two lines prevent a horrible warning message from appearing
|
||||||
|
# when the frozen CLIP tokenizer is imported
|
||||||
|
import transformers # type: ignore
|
||||||
|
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
import diffusers
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# normalize the config directory relative to root
|
||||||
|
if not os.path.isabs(config.conf):
|
||||||
|
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||||
|
|
||||||
|
if config.embeddings:
|
||||||
|
if not os.path.isabs(config.embedding_path):
|
||||||
|
embedding_path = os.path.normpath(
|
||||||
|
os.path.join(Globals.root, config.embedding_path)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_path = config.embedding_path
|
||||||
|
else:
|
||||||
|
embedding_path = None
|
||||||
|
|
||||||
|
# migrate legacy models
|
||||||
|
ModelManager.migrate_models()
|
||||||
|
|
||||||
|
# creating the model manager
|
||||||
|
try:
|
||||||
|
device = torch.device(choose_torch_device())
|
||||||
|
precision = 'float16' if config.precision=='float16' \
|
||||||
|
else 'float32' if config.precision=='float32' \
|
||||||
|
else choose_precision(device)
|
||||||
|
|
||||||
|
model_manager = ModelManager(
|
||||||
|
OmegaConf.load(config.conf),
|
||||||
|
precision=precision,
|
||||||
|
device_type=device,
|
||||||
|
max_loaded_models=config.max_loaded_models,
|
||||||
|
embedding_path = Path(embedding_path),
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
|
report_model_error(config, e)
|
||||||
|
except (IOError, KeyError) as e:
|
||||||
|
print(f"{e}. Aborting.")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
# try to autoconvert new models
|
||||||
|
# autoimport new .ckpt files
|
||||||
|
if path := config.autoconvert:
|
||||||
|
model_manager.autoconvert_weights(
|
||||||
|
conf_path=config.conf,
|
||||||
|
weights_directory=path,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_manager
|
||||||
|
|
||||||
|
def report_model_error(opt: Namespace, e: Exception):
|
||||||
|
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||||
|
print(
|
||||||
|
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||||
|
)
|
||||||
|
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||||
|
if yes_to_all:
|
||||||
|
print(
|
||||||
|
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = input(
|
||||||
|
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
|
||||||
|
)
|
||||||
|
if response.startswith(("n", "N")):
|
||||||
|
return
|
||||||
|
|
||||||
|
print("invokeai-configure is launching....\n")
|
||||||
|
|
||||||
|
# Match arguments that were set on the CLI
|
||||||
|
# only the arguments accepted by the configuration script are parsed
|
||||||
|
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||||
|
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||||
|
previous_config = sys.argv
|
||||||
|
sys.argv = ["invokeai-configure"]
|
||||||
|
sys.argv.extend(root_dir)
|
||||||
|
sys.argv.extend(config.to_dict())
|
||||||
|
if yes_to_all is not None:
|
||||||
|
for arg in yes_to_all.split():
|
||||||
|
sys.argv.append(arg)
|
||||||
|
|
||||||
|
from invokeai.frontend.install import invokeai_configure
|
||||||
|
|
||||||
|
invokeai_configure()
|
||||||
|
# TODO: Figure out how to restart
|
||||||
|
# print('** InvokeAI will now restart')
|
||||||
|
# sys.argv = previous_args
|
||||||
|
# main() # would rather do a os.exec(), but doesn't exist?
|
||||||
|
# sys.exit(0)
|
109
invokeai/app/services/restoration_services.py
Normal file
109
invokeai/app/services/restoration_services.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
from ...backend.restoration import Restoration
|
||||||
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||||
|
|
||||||
|
# This should be a real base class for postprocessing functions,
|
||||||
|
# but right now we just instantiate the existing gfpgan, esrgan
|
||||||
|
# and codeformer functions.
|
||||||
|
class RestorationServices:
|
||||||
|
'''Face restoration and upscaling'''
|
||||||
|
|
||||||
|
def __init__(self,args):
|
||||||
|
try:
|
||||||
|
gfpgan, codeformer, esrgan = None, None, None
|
||||||
|
if args.restore or args.esrgan:
|
||||||
|
restoration = Restoration()
|
||||||
|
if args.restore:
|
||||||
|
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||||
|
args.gfpgan_model_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> Face restoration disabled")
|
||||||
|
if args.esrgan:
|
||||||
|
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||||
|
else:
|
||||||
|
print(">> Upscaling disabled")
|
||||||
|
else:
|
||||||
|
print(">> Face restoration and upscaling disabled")
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
|
self.device = torch.device(choose_torch_device())
|
||||||
|
self.gfpgan = gfpgan
|
||||||
|
self.codeformer = codeformer
|
||||||
|
self.esrgan = esrgan
|
||||||
|
|
||||||
|
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||||
|
# esrgan upscaling
|
||||||
|
# TO DO: refactor into separate methods
|
||||||
|
def upscale_and_reconstruct(
|
||||||
|
self,
|
||||||
|
image_list,
|
||||||
|
facetool="gfpgan",
|
||||||
|
upscale=None,
|
||||||
|
upscale_denoise_str=0.75,
|
||||||
|
strength=0.0,
|
||||||
|
codeformer_fidelity=0.75,
|
||||||
|
save_original=False,
|
||||||
|
image_callback=None,
|
||||||
|
prefix=None,
|
||||||
|
):
|
||||||
|
results = []
|
||||||
|
for r in image_list:
|
||||||
|
image, seed = r
|
||||||
|
try:
|
||||||
|
if strength > 0:
|
||||||
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
|
if facetool == "gfpgan":
|
||||||
|
if self.gfpgan is None:
|
||||||
|
print(
|
||||||
|
">> GFPGAN not found. Face restoration is disabled."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
|
if facetool == "codeformer":
|
||||||
|
if self.codeformer is None:
|
||||||
|
print(
|
||||||
|
">> CodeFormer not found. Face restoration is disabled."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cf_device = (
|
||||||
|
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
||||||
|
)
|
||||||
|
image = self.codeformer.process(
|
||||||
|
image=image,
|
||||||
|
strength=strength,
|
||||||
|
device=cf_device,
|
||||||
|
seed=seed,
|
||||||
|
fidelity=codeformer_fidelity,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> Face Restoration is disabled.")
|
||||||
|
if upscale is not None:
|
||||||
|
if self.esrgan is not None:
|
||||||
|
if len(upscale) < 2:
|
||||||
|
upscale.append(0.75)
|
||||||
|
image = self.esrgan.process(
|
||||||
|
image,
|
||||||
|
upscale[1],
|
||||||
|
seed,
|
||||||
|
int(upscale[0]),
|
||||||
|
denoise_str=upscale_denoise_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_callback is not None:
|
||||||
|
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
||||||
|
else:
|
||||||
|
r[0] = image
|
||||||
|
|
||||||
|
results.append([image, seed])
|
||||||
|
|
||||||
|
return results
|
@ -2,6 +2,15 @@
|
|||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .generate import Generate
|
from .generate import Generate
|
||||||
|
from .generator import (
|
||||||
|
InvokeAIGeneratorBasicParams,
|
||||||
|
InvokeAIGenerator,
|
||||||
|
InvokeAIGeneratorOutput,
|
||||||
|
Txt2Img,
|
||||||
|
Img2Img,
|
||||||
|
Inpaint
|
||||||
|
)
|
||||||
from .model_management import ModelManager
|
from .model_management import ModelManager
|
||||||
|
from .safety_checker import SafetyChecker
|
||||||
from .args import Args
|
from .args import Args
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
|
@ -25,18 +25,19 @@ from accelerate.utils import set_seed
|
|||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .args import metadata_from_png
|
from .args import metadata_from_png
|
||||||
from .generator import infill_methods
|
from .generator import infill_methods
|
||||||
from .globals import Globals, global_cache_dir
|
from .globals import Globals, global_cache_dir
|
||||||
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
||||||
from .model_management import ModelManager
|
from .model_management import ModelManager
|
||||||
|
from .safety_checker import SafetyChecker
|
||||||
from .prompting import get_uc_and_c_and_ec
|
from .prompting import get_uc_and_c_and_ec
|
||||||
from .prompting.conditioning import log_tokenization
|
from .prompting.conditioning import log_tokenization
|
||||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||||
from .util import choose_precision, choose_torch_device
|
from .util import choose_precision, choose_torch_device
|
||||||
|
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
|
||||||
@ -222,6 +223,7 @@ class Generate:
|
|||||||
self.precision,
|
self.precision,
|
||||||
max_loaded_models=max_loaded_models,
|
max_loaded_models=max_loaded_models,
|
||||||
sequential_offload=self.free_gpu_mem,
|
sequential_offload=self.free_gpu_mem,
|
||||||
|
embedding_path=Path(self.embedding_path),
|
||||||
)
|
)
|
||||||
# don't accept invalid models
|
# don't accept invalid models
|
||||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
@ -244,31 +246,8 @@ class Generate:
|
|||||||
|
|
||||||
# load safety checker if requested
|
# load safety checker if requested
|
||||||
if safety_checker:
|
if safety_checker:
|
||||||
try:
|
print(">> Initializing NSFW checker")
|
||||||
print(">> Initializing NSFW checker")
|
self.safety_checker = SafetyChecker(self.device)
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
)
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
safety_model_path = global_cache_dir("hub")
|
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
|
||||||
safety_model_id,
|
|
||||||
local_files_only=True,
|
|
||||||
cache_dir=safety_model_path,
|
|
||||||
)
|
|
||||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
||||||
safety_model_id,
|
|
||||||
local_files_only=True,
|
|
||||||
cache_dir=safety_model_path,
|
|
||||||
)
|
|
||||||
self.safety_checker.to(self.device)
|
|
||||||
except Exception:
|
|
||||||
print(
|
|
||||||
"** An error was encountered while installing the safety checker:"
|
|
||||||
)
|
|
||||||
print(traceback.format_exc())
|
|
||||||
else:
|
else:
|
||||||
print(">> NSFW checker is disabled")
|
print(">> NSFW checker is disabled")
|
||||||
|
|
||||||
@ -523,15 +502,6 @@ class Generate:
|
|||||||
generator.set_variation(self.seed, variation_amount, with_variations)
|
generator.set_variation(self.seed, variation_amount, with_variations)
|
||||||
generator.use_mps_noise = use_mps_noise
|
generator.use_mps_noise = use_mps_noise
|
||||||
|
|
||||||
checker = (
|
|
||||||
{
|
|
||||||
"checker": self.safety_checker,
|
|
||||||
"extractor": self.safety_feature_extractor,
|
|
||||||
}
|
|
||||||
if self.safety_checker
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
@ -558,7 +528,7 @@ class Generate:
|
|||||||
embiggen_strength=embiggen_strength,
|
embiggen_strength=embiggen_strength,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
mask_blur_radius=mask_blur_radius,
|
mask_blur_radius=mask_blur_radius,
|
||||||
safety_checker=checker,
|
safety_checker=self.safety_checker,
|
||||||
seam_size=seam_size,
|
seam_size=seam_size,
|
||||||
seam_blur=seam_blur,
|
seam_blur=seam_blur,
|
||||||
seam_strength=seam_strength,
|
seam_strength=seam_strength,
|
||||||
@ -940,18 +910,6 @@ class Generate:
|
|||||||
self.generators = {}
|
self.generators = {}
|
||||||
|
|
||||||
set_seed(random.randrange(0, np.iinfo(np.uint32).max))
|
set_seed(random.randrange(0, np.iinfo(np.uint32).max))
|
||||||
if self.embedding_path is not None:
|
|
||||||
print(f">> Loading embeddings from {self.embedding_path}")
|
|
||||||
for root, _, files in os.walk(self.embedding_path):
|
|
||||||
for name in files:
|
|
||||||
ti_path = os.path.join(root, name)
|
|
||||||
self.model.textual_inversion_manager.load_textual_inversion(
|
|
||||||
ti_path, defer_injecting_tokens=True
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self._set_scheduler() # requires self.model_name to be set first
|
self._set_scheduler() # requires self.model_name to be set first
|
||||||
return self.model
|
return self.model
|
||||||
@ -998,7 +956,7 @@ class Generate:
|
|||||||
):
|
):
|
||||||
results = []
|
results = []
|
||||||
for r in image_list:
|
for r in image_list:
|
||||||
image, seed = r
|
image, seed, _ = r
|
||||||
try:
|
try:
|
||||||
if strength > 0:
|
if strength > 0:
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
|
@ -1,5 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for the invokeai.generator package
|
Initialization file for the invokeai.generator package
|
||||||
"""
|
"""
|
||||||
from .base import Generator
|
from .base import (
|
||||||
|
InvokeAIGenerator,
|
||||||
|
InvokeAIGeneratorBasicParams,
|
||||||
|
InvokeAIGeneratorOutput,
|
||||||
|
Txt2Img,
|
||||||
|
Img2Img,
|
||||||
|
Inpaint,
|
||||||
|
Generator,
|
||||||
|
)
|
||||||
from .inpaint import infill_methods
|
from .inpaint import infill_methods
|
||||||
|
@ -4,11 +4,15 @@ including img2img, txt2img, and inpaint
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import dataclasses
|
||||||
|
import diffusers
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
from abc import ABCMeta
|
||||||
|
from argparse import Namespace
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -17,13 +21,258 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
from typing import List, Iterator, Type
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
|
from ..safety_checker import SafetyChecker
|
||||||
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
CAUTION_IMG = "caution.png"
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorBasicParams:
|
||||||
|
seed: int=None
|
||||||
|
width: int=512
|
||||||
|
height: int=512
|
||||||
|
cfg_scale: int=7.5
|
||||||
|
steps: int=20
|
||||||
|
ddim_eta: float=0.0
|
||||||
|
scheduler: int='ddim'
|
||||||
|
precision: str='float16'
|
||||||
|
perlin: float=0.0
|
||||||
|
threshold: int=0.0
|
||||||
|
seamless: bool=False
|
||||||
|
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||||
|
h_symmetry_time_pct: float=None
|
||||||
|
v_symmetry_time_pct: float=None
|
||||||
|
variation_amount: float = 0.0
|
||||||
|
with_variations: list=field(default_factory=list)
|
||||||
|
safety_checker: SafetyChecker=None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorOutput:
|
||||||
|
'''
|
||||||
|
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||||
|
operation, including the image, its seed, the model name used to generate the image
|
||||||
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
|
generating the image (in .params, also available as attributes)
|
||||||
|
'''
|
||||||
|
image: Image
|
||||||
|
seed: int
|
||||||
|
model_hash: str
|
||||||
|
attention_maps_images: List[Image]
|
||||||
|
params: Namespace
|
||||||
|
|
||||||
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
|
# old code that calls Generate will continue to work.
|
||||||
|
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||||
|
scheduler_map = dict(
|
||||||
|
ddim=diffusers.DDIMScheduler,
|
||||||
|
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||||
|
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||||
|
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_euler=diffusers.EulerDiscreteScheduler,
|
||||||
|
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||||
|
k_heun=diffusers.HeunDiscreteScheduler,
|
||||||
|
k_lms=diffusers.LMSDiscreteScheduler,
|
||||||
|
plms=diffusers.PNDMScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_info: dict,
|
||||||
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
|
):
|
||||||
|
self.model_info=model_info
|
||||||
|
self.params=params
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
prompt: str='',
|
||||||
|
callback: callable=None,
|
||||||
|
step_callback: callable=None,
|
||||||
|
iterations: int=1,
|
||||||
|
**keyword_args,
|
||||||
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
|
'''
|
||||||
|
Return an iterator across the indicated number of generations.
|
||||||
|
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||||
|
object. Use like this:
|
||||||
|
|
||||||
|
outputs = txt2img.generate(prompt='banana sushi', iterations=5)
|
||||||
|
for result in outputs:
|
||||||
|
print(result.image, result.seed)
|
||||||
|
|
||||||
|
In the typical case of wanting to get just a single image, iterations
|
||||||
|
defaults to 1 and do:
|
||||||
|
|
||||||
|
output = next(txt2img.generate(prompt='banana sushi')
|
||||||
|
|
||||||
|
Pass None to get an infinite iterator.
|
||||||
|
|
||||||
|
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
|
||||||
|
for o in outputs:
|
||||||
|
print(o.image, o.seed)
|
||||||
|
|
||||||
|
'''
|
||||||
|
generator_args = dataclasses.asdict(self.params)
|
||||||
|
generator_args.update(keyword_args)
|
||||||
|
|
||||||
|
model_info = self.model_info
|
||||||
|
model_name = model_info['model_name']
|
||||||
|
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
|
model_hash = model_info['hash']
|
||||||
|
scheduler: Scheduler = self.get_scheduler(
|
||||||
|
model=model,
|
||||||
|
scheduler_name=generator_args.get('scheduler')
|
||||||
|
)
|
||||||
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||||
|
gen_class = self._generator_class()
|
||||||
|
generator = gen_class(model, self.params.precision)
|
||||||
|
if self.params.variation_amount > 0:
|
||||||
|
generator.set_variation(generator_args.get('seed'),
|
||||||
|
generator_args.get('variation_amount'),
|
||||||
|
generator_args.get('with_variations')
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(model, DiffusionPipeline):
|
||||||
|
for component in [model.unet, model.vae]:
|
||||||
|
configure_model_padding(component,
|
||||||
|
generator_args.get('seamless',False),
|
||||||
|
generator_args.get('seamless_axes')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configure_model_padding(model,
|
||||||
|
generator_args.get('seamless',False),
|
||||||
|
generator_args.get('seamless_axes')
|
||||||
|
)
|
||||||
|
|
||||||
|
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||||
|
for i in iteration_count:
|
||||||
|
results = generator.generate(prompt,
|
||||||
|
conditioning=(uc, c, extra_conditioning_info),
|
||||||
|
sampler=scheduler,
|
||||||
|
**generator_args,
|
||||||
|
)
|
||||||
|
output = InvokeAIGeneratorOutput(
|
||||||
|
image=results[0][0],
|
||||||
|
seed=results[0][1],
|
||||||
|
attention_maps_images=results[0][2],
|
||||||
|
model_hash = model_hash,
|
||||||
|
params=Namespace(model_name=model_name,**generator_args),
|
||||||
|
)
|
||||||
|
if callback:
|
||||||
|
callback(output)
|
||||||
|
yield output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def schedulers(self)->List[str]:
|
||||||
|
'''
|
||||||
|
Return list of all the schedulers that we currently handle.
|
||||||
|
'''
|
||||||
|
return list(self.scheduler_map.keys())
|
||||||
|
|
||||||
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
|
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||||
|
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||||
|
# hack copied over from generate.py
|
||||||
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls)->Type[Generator]:
|
||||||
|
'''
|
||||||
|
In derived classes return the name of the generator to apply.
|
||||||
|
If you don't override will return the name of the derived
|
||||||
|
class, which nicely parallels the generator class names.
|
||||||
|
'''
|
||||||
|
return Generator
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Txt2Img(InvokeAIGenerator):
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .txt2img import Txt2Img
|
||||||
|
return Txt2Img
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Img2Img(InvokeAIGenerator):
|
||||||
|
def generate(self,
|
||||||
|
init_image: Image | torch.FloatTensor,
|
||||||
|
strength: float=0.75,
|
||||||
|
**keyword_args
|
||||||
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(init_image=init_image,
|
||||||
|
strength=strength,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .img2img import Img2Img
|
||||||
|
return Img2Img
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
|
class Inpaint(Img2Img):
|
||||||
|
def generate(self,
|
||||||
|
mask_image: Image | torch.FloatTensor,
|
||||||
|
# Seam settings - when 0, doesn't fill seam
|
||||||
|
seam_size: int = 0,
|
||||||
|
seam_blur: int = 0,
|
||||||
|
seam_strength: float = 0.7,
|
||||||
|
seam_steps: int = 10,
|
||||||
|
tile_size: int = 32,
|
||||||
|
inpaint_replace=False,
|
||||||
|
infill_method=None,
|
||||||
|
inpaint_width=None,
|
||||||
|
inpaint_height=None,
|
||||||
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
|
**keyword_args
|
||||||
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(
|
||||||
|
mask_image=mask_image,
|
||||||
|
seam_size=seam_size,
|
||||||
|
seam_blur=seam_blur,
|
||||||
|
seam_strength=seam_strength,
|
||||||
|
seam_steps=seam_steps,
|
||||||
|
tile_size=tile_size,
|
||||||
|
inpaint_replace=inpaint_replace,
|
||||||
|
infill_method=infill_method,
|
||||||
|
inpaint_width=inpaint_width,
|
||||||
|
inpaint_height=inpaint_height,
|
||||||
|
inpaint_fill=inpaint_fill,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .inpaint import Inpaint
|
||||||
|
return Inpaint
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Embiggen(Txt2Img):
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
embiggen: list=None,
|
||||||
|
embiggen_tiles: list = None,
|
||||||
|
strength: float=0.75,
|
||||||
|
**kwargs)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(embiggen=embiggen,
|
||||||
|
embiggen_tiles=embiggen_tiles,
|
||||||
|
strength=strength,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .embiggen import Embiggen
|
||||||
|
return Embiggen
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
@ -44,7 +293,6 @@ class Generator:
|
|||||||
self.with_variations = []
|
self.with_variations = []
|
||||||
self.use_mps_noise = False
|
self.use_mps_noise = False
|
||||||
self.free_gpu_mem = None
|
self.free_gpu_mem = None
|
||||||
self.caution_img = None
|
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self, prompt, **kwargs):
|
def get_make_image(self, prompt, **kwargs):
|
||||||
@ -64,10 +312,10 @@ class Generator:
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
init_image,
|
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sampler,
|
sampler,
|
||||||
|
init_image=None,
|
||||||
iterations=1,
|
iterations=1,
|
||||||
seed=None,
|
seed=None,
|
||||||
image_callback=None,
|
image_callback=None,
|
||||||
@ -76,7 +324,7 @@ class Generator:
|
|||||||
perlin=0.0,
|
perlin=0.0,
|
||||||
h_symmetry_time_pct=None,
|
h_symmetry_time_pct=None,
|
||||||
v_symmetry_time_pct=None,
|
v_symmetry_time_pct=None,
|
||||||
safety_checker: dict = None,
|
safety_checker: SafetyChecker=None,
|
||||||
free_gpu_mem: bool = False,
|
free_gpu_mem: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -130,9 +378,9 @@ class Generator:
|
|||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
|
|
||||||
if self.safety_checker is not None:
|
if self.safety_checker is not None:
|
||||||
image = self.safety_check(image)
|
image = self.safety_checker.check(image)
|
||||||
|
|
||||||
results.append([image, seed])
|
results.append([image, seed, attention_maps_images])
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
attention_maps_image = (
|
attention_maps_image = (
|
||||||
@ -292,16 +540,6 @@ class Generator:
|
|||||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
return (seed, initial_noise)
|
return (seed, initial_noise)
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
|
||||||
def get_noise(self, width, height):
|
|
||||||
"""
|
|
||||||
Returns a tensor filled with random numbers, either form a normal distribution
|
|
||||||
(txt2img) or from the latent image (img2img, inpaint)
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"get_noise() must be implemented in a descendent class"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_perlin_noise(self, width, height):
|
def get_perlin_noise(self, width, height):
|
||||||
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
|
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
|
||||||
# limit noise to only the diffusion image channels, not the mask channels
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
@ -361,53 +599,6 @@ class Generator:
|
|||||||
|
|
||||||
return v2
|
return v2
|
||||||
|
|
||||||
def safety_check(self, image: Image.Image):
|
|
||||||
"""
|
|
||||||
If the CompViz safety checker flags an NSFW image, we
|
|
||||||
blur it out.
|
|
||||||
"""
|
|
||||||
import diffusers
|
|
||||||
|
|
||||||
checker = self.safety_checker["checker"]
|
|
||||||
extractor = self.safety_checker["extractor"]
|
|
||||||
features = extractor([image], return_tensors="pt")
|
|
||||||
features.to(self.model.device)
|
|
||||||
|
|
||||||
# unfortunately checker requires the numpy version, so we have to convert back
|
|
||||||
x_image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
|
||||||
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
checked_image, has_nsfw_concept = checker(
|
|
||||||
images=x_image, clip_input=features.pixel_values
|
|
||||||
)
|
|
||||||
if has_nsfw_concept[0]:
|
|
||||||
print(
|
|
||||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
|
||||||
)
|
|
||||||
return self.blur(image)
|
|
||||||
else:
|
|
||||||
return image
|
|
||||||
|
|
||||||
def blur(self, input):
|
|
||||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
||||||
try:
|
|
||||||
caution = self.get_caution_img()
|
|
||||||
if caution:
|
|
||||||
blurry.paste(caution, (0, 0), caution)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
return blurry
|
|
||||||
|
|
||||||
def get_caution_img(self):
|
|
||||||
path = None
|
|
||||||
if self.caution_img:
|
|
||||||
return self.caution_img
|
|
||||||
path = Path(web_assets.__path__[0]) / CAUTION_IMG
|
|
||||||
caution = Image.open(path)
|
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
return self.caution_img
|
|
||||||
|
|
||||||
# this is a handy routine for debugging use. Given a generated sample,
|
# this is a handy routine for debugging use. Given a generated sample,
|
||||||
# convert it into a PNG image and store it at the indicated path
|
# convert it into a PNG image and store it at the indicated path
|
||||||
def save_sample(self, sample, filepath):
|
def save_sample(self, sample, filepath):
|
||||||
|
@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
from invokeai.backend.globals import Globals, global_cache_dir
|
from invokeai.backend.globals import Globals, global_cache_dir
|
||||||
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
from ..util import CPU_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
V1 = 1
|
V1 = 1
|
||||||
@ -51,23 +50,29 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
|||||||
}
|
}
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
|
'''
|
||||||
|
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||||
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OmegaConf,
|
config: OmegaConf|Path,
|
||||||
device_type: torch.device = CPU_DEVICE,
|
device_type: torch.device = CUDA_DEVICE,
|
||||||
precision: str = "float16",
|
precision: str = "float16",
|
||||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||||
sequential_offload=False,
|
sequential_offload=False,
|
||||||
|
embedding_path: Path=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file,
|
Initialize with the path to the models.yaml config file or
|
||||||
the torch device type, and precision. The optional
|
an initialized OmegaConf dictionary. Optional parameters
|
||||||
min_avail_mem argument specifies how much unused system
|
are the torch device type, precision, max_loaded_models,
|
||||||
(CPU) memory to preserve. The cache of models in RAM will
|
and sequential_offload boolean. Note that the default device
|
||||||
grow until this value is approached. Default is 2G.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
"""
|
"""
|
||||||
# prevent nasty-looking CLIP log message
|
# prevent nasty-looking CLIP log message
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
if not isinstance(config, DictConfig):
|
||||||
|
config = OmegaConf.load(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.device = torch.device(device_type)
|
self.device = torch.device(device_type)
|
||||||
@ -76,6 +81,7 @@ class ModelManager(object):
|
|||||||
self.stack = [] # this is an LRU FIFO
|
self.stack = [] # this is an LRU FIFO
|
||||||
self.current_model = None
|
self.current_model = None
|
||||||
self.sequential_offload = sequential_offload
|
self.sequential_offload = sequential_offload
|
||||||
|
self.embedding_path = embedding_path
|
||||||
|
|
||||||
def valid_model(self, model_name: str) -> bool:
|
def valid_model(self, model_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -84,12 +90,15 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
return model_name in self.config
|
return model_name in self.config
|
||||||
|
|
||||||
def get_model(self, model_name: str):
|
def get_model(self, model_name: str=None)->dict:
|
||||||
"""
|
"""
|
||||||
Given a model named identified in models.yaml, return
|
Given a model named identified in models.yaml, return
|
||||||
the model object. If in RAM will load into GPU VRAM.
|
the model object. If in RAM will load into GPU VRAM.
|
||||||
If on disk, will load from there.
|
If on disk, will load from there.
|
||||||
"""
|
"""
|
||||||
|
if not model_name:
|
||||||
|
return self.current_model if self.current_model else self.get_model(self.default_model())
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
if not self.valid_model(model_name):
|
||||||
print(
|
print(
|
||||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
@ -112,6 +121,7 @@ class ModelManager(object):
|
|||||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||||
requested_model, width, height, hash = self._load_model(model_name)
|
requested_model, width, height, hash = self._load_model(model_name)
|
||||||
self.models[model_name] = {
|
self.models[model_name] = {
|
||||||
|
"model_name": model_name,
|
||||||
"model": requested_model,
|
"model": requested_model,
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
@ -121,6 +131,7 @@ class ModelManager(object):
|
|||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
return {
|
return {
|
||||||
|
"model_name": model_name,
|
||||||
"model": requested_model,
|
"model": requested_model,
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
@ -425,6 +436,7 @@ class ModelManager(object):
|
|||||||
height = width
|
height = width
|
||||||
|
|
||||||
print(f" | Default image dimensions = {width} x {height}")
|
print(f" | Default image dimensions = {width} x {height}")
|
||||||
|
self._add_embeddings_to_model(pipeline)
|
||||||
|
|
||||||
return pipeline, width, height, model_hash
|
return pipeline, width, height, model_hash
|
||||||
|
|
||||||
@ -1061,6 +1073,19 @@ class ModelManager(object):
|
|||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
self.stack.append(model_name)
|
self.stack.append(model_name)
|
||||||
|
|
||||||
|
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||||
|
if self.embedding_path is not None:
|
||||||
|
print(f">> Loading embeddings from {self.embedding_path}")
|
||||||
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
|
for name in files:
|
||||||
|
ti_path = os.path.join(root, name)
|
||||||
|
model.textual_inversion_manager.load_textual_inversion(
|
||||||
|
ti_path, defer_injecting_tokens=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||||
|
)
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
def _has_cuda(self) -> bool:
|
||||||
return self.device.type == "cuda"
|
return self.device.type == "cuda"
|
||||||
|
|
||||||
|
82
invokeai/backend/safety_checker.py
Normal file
82
invokeai/backend/safety_checker.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
'''
|
||||||
|
SafetyChecker class - checks images against the StabilityAI NSFW filter
|
||||||
|
and blurs images that contain potential NSFW content.
|
||||||
|
'''
|
||||||
|
import diffusers
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import traceback
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||||
|
StableDiffusionSafetyChecker,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image, ImageFilter
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
|
import invokeai.assets.web as web_assets
|
||||||
|
from .globals import global_cache_dir
|
||||||
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
|
class SafetyChecker(object):
|
||||||
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
|
def __init__(self, device: torch.device):
|
||||||
|
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||||
|
caution = Image.open(path)
|
||||||
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
try:
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
safety_model_path = global_cache_dir("hub")
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
|
safety_model_id,
|
||||||
|
local_files_only=True,
|
||||||
|
cache_dir=safety_model_path,
|
||||||
|
)
|
||||||
|
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
safety_model_id,
|
||||||
|
local_files_only=True,
|
||||||
|
cache_dir=safety_model_path,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
print(
|
||||||
|
"** An error was encountered while installing the safety checker:"
|
||||||
|
)
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
def check(self, image: Image.Image):
|
||||||
|
"""
|
||||||
|
Check provided image against the StabilityAI safety checker and return
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.safety_checker.to(self.device)
|
||||||
|
features = self.safety_feature_extractor([image], return_tensors="pt")
|
||||||
|
features.to(self.device)
|
||||||
|
|
||||||
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
checked_image, has_nsfw_concept = self.safety_checker(
|
||||||
|
images=x_image, clip_input=features.pixel_values
|
||||||
|
)
|
||||||
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
|
if has_nsfw_concept[0]:
|
||||||
|
print(
|
||||||
|
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||||
|
)
|
||||||
|
return self.blur(image)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def blur(self, input):
|
||||||
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
try:
|
||||||
|
if caution := self.caution_img:
|
||||||
|
blurry.paste(caution, (0, 0), caution)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return blurry
|
BIN
static/dream_web/favicon.ico
Normal file
BIN
static/dream_web/favicon.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.1 KiB |
179
static/dream_web/index.css
Normal file
179
static/dream_web/index.css
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
:root {
|
||||||
|
--fields-dark:#DCDCDC;
|
||||||
|
--fields-light:#F5F5F5;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
font-family: 'Arial';
|
||||||
|
font-size: 100%;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-size: 1em;
|
||||||
|
}
|
||||||
|
textarea {
|
||||||
|
font-size: 0.95em;
|
||||||
|
}
|
||||||
|
header, form, #progress-section {
|
||||||
|
margin-left: auto;
|
||||||
|
margin-right: auto;
|
||||||
|
max-width: 1024px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
fieldset {
|
||||||
|
border: none;
|
||||||
|
line-height: 2.2em;
|
||||||
|
}
|
||||||
|
fieldset > legend {
|
||||||
|
width: auto;
|
||||||
|
margin-left: 0;
|
||||||
|
margin-right: auto;
|
||||||
|
font-weight:bold;
|
||||||
|
}
|
||||||
|
select, input {
|
||||||
|
margin-right: 10px;
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
input:disabled {
|
||||||
|
cursor:auto;
|
||||||
|
}
|
||||||
|
input[type=submit] {
|
||||||
|
cursor: pointer;
|
||||||
|
background-color: #666;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
input[type=checkbox] {
|
||||||
|
cursor: pointer;
|
||||||
|
margin-right: 0px;
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
input#seed {
|
||||||
|
margin-right: 0px;
|
||||||
|
}
|
||||||
|
div {
|
||||||
|
padding: 10px 10px 10px 10px;
|
||||||
|
}
|
||||||
|
header {
|
||||||
|
margin-bottom: 16px;
|
||||||
|
}
|
||||||
|
header h1 {
|
||||||
|
margin-bottom: 0;
|
||||||
|
font-size: 2em;
|
||||||
|
}
|
||||||
|
#search-box {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
#scaling-inprocess-message {
|
||||||
|
font-weight: bold;
|
||||||
|
font-style: italic;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
#prompt {
|
||||||
|
flex-grow: 1;
|
||||||
|
padding: 5px 10px 5px 10px;
|
||||||
|
border: 1px solid #999;
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
#submit {
|
||||||
|
padding: 5px 10px 5px 10px;
|
||||||
|
border: 1px solid #999;
|
||||||
|
}
|
||||||
|
#reset-all, #remove-image {
|
||||||
|
margin-top: 12px;
|
||||||
|
font-size: 0.8em;
|
||||||
|
background-color: pink;
|
||||||
|
border: 1px solid #999;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
#results {
|
||||||
|
text-align: center;
|
||||||
|
margin: auto;
|
||||||
|
padding-top: 10px;
|
||||||
|
}
|
||||||
|
#results figure {
|
||||||
|
display: inline-block;
|
||||||
|
margin: 10px;
|
||||||
|
}
|
||||||
|
#results figcaption {
|
||||||
|
font-size: 0.8em;
|
||||||
|
padding: 3px;
|
||||||
|
color: #888;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
#results img {
|
||||||
|
border-radius: 5px;
|
||||||
|
object-fit: contain;
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#fieldset-config {
|
||||||
|
line-height:2em;
|
||||||
|
}
|
||||||
|
input[type="number"] {
|
||||||
|
width: 60px;
|
||||||
|
}
|
||||||
|
#seed {
|
||||||
|
width: 150px;
|
||||||
|
}
|
||||||
|
button#reset-seed {
|
||||||
|
font-size: 1.7em;
|
||||||
|
background: #efefef;
|
||||||
|
border: 1px solid #999;
|
||||||
|
border-radius: 4px;
|
||||||
|
line-height: 0.8;
|
||||||
|
margin: 0 10px 0 0;
|
||||||
|
padding: 0 5px 3px;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
#progress-section {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
#progress-image {
|
||||||
|
width: 30vh;
|
||||||
|
height: 30vh;
|
||||||
|
object-fit: contain;
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#cancel-button {
|
||||||
|
cursor: pointer;
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
#txt2img {
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#variations {
|
||||||
|
background-color: var(--fields-light);
|
||||||
|
}
|
||||||
|
#initimg {
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#img2img {
|
||||||
|
background-color: var(--fields-light);
|
||||||
|
}
|
||||||
|
#initimg > :not(legend) {
|
||||||
|
background-color: var(--fields-light);
|
||||||
|
margin: .5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#postprocess, #initimg {
|
||||||
|
display:flex;
|
||||||
|
flex-wrap:wrap;
|
||||||
|
padding: 0;
|
||||||
|
margin-top: 1em;
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#postprocess > fieldset, #initimg > * {
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
#postprocess > fieldset {
|
||||||
|
background-color: var(--fields-dark);
|
||||||
|
}
|
||||||
|
#progress-section {
|
||||||
|
background-color: var(--fields-light);
|
||||||
|
}
|
||||||
|
#no-results-message:not(:only-child) {
|
||||||
|
display: none;
|
||||||
|
}
|
187
static/dream_web/index.html
Normal file
187
static/dream_web/index.html
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<title>Stable Diffusion Dream Server</title>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
|
||||||
|
<script src="config.js"></script>
|
||||||
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
|
||||||
|
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
|
||||||
|
crossorigin="anonymous"></script>
|
||||||
|
<link rel="stylesheet" href="index.css">
|
||||||
|
<script src="index.js"></script>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<header>
|
||||||
|
<h1>Stable Diffusion Dream Server</h1>
|
||||||
|
<div id="about">
|
||||||
|
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub
|
||||||
|
site</a>
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<main>
|
||||||
|
<!--
|
||||||
|
<div id="dropper" style="background-color:red;width:200px;height:200px;">
|
||||||
|
</div>
|
||||||
|
-->
|
||||||
|
<form id="generate-form" method="post" action="api/jobs">
|
||||||
|
<fieldset id="txt2img">
|
||||||
|
<legend>
|
||||||
|
<input type="checkbox" name="enable_generate" id="enable_generate" checked>
|
||||||
|
<label for="enable_generate">Generate</label>
|
||||||
|
</legend>
|
||||||
|
<div id="search-box">
|
||||||
|
<textarea rows="3" id="prompt" name="prompt"></textarea>
|
||||||
|
</div>
|
||||||
|
<label for="iterations">Images to generate:</label>
|
||||||
|
<input value="1" type="number" id="iterations" name="iterations" size="4">
|
||||||
|
<label for="steps">Steps:</label>
|
||||||
|
<input value="50" type="number" id="steps" name="steps">
|
||||||
|
<label for="cfg_scale">Cfg Scale:</label>
|
||||||
|
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
|
||||||
|
<label for="sampler_name">Sampler:</label>
|
||||||
|
<select id="sampler_name" name="sampler_name" value="k_lms">
|
||||||
|
<option value="ddim">DDIM</option>
|
||||||
|
<option value="plms">PLMS</option>
|
||||||
|
<option value="k_lms" selected>KLMS</option>
|
||||||
|
<option value="k_dpm_2">KDPM_2</option>
|
||||||
|
<option value="k_dpm_2_a">KDPM_2A</option>
|
||||||
|
<option value="k_dpmpp_2">KDPMPP_2</option>
|
||||||
|
<option value="k_dpmpp_2_a">KDPMPP_2A</option>
|
||||||
|
<option value="k_euler">KEULER</option>
|
||||||
|
<option value="k_euler_a">KEULER_A</option>
|
||||||
|
<option value="k_heun">KHEUN</option>
|
||||||
|
</select>
|
||||||
|
<input type="checkbox" name="seamless" id="seamless">
|
||||||
|
<label for="seamless">Seamless circular tiling</label>
|
||||||
|
<br>
|
||||||
|
<label title="Set to multiple of 64" for="width">Width:</label>
|
||||||
|
<select id="width" name="width" value="512">
|
||||||
|
<option value="64">64</option>
|
||||||
|
<option value="128">128</option>
|
||||||
|
<option value="192">192</option>
|
||||||
|
<option value="256">256</option>
|
||||||
|
<option value="320">320</option>
|
||||||
|
<option value="384">384</option>
|
||||||
|
<option value="448">448</option>
|
||||||
|
<option value="512" selected>512</option>
|
||||||
|
<option value="576">576</option>
|
||||||
|
<option value="640">640</option>
|
||||||
|
<option value="704">704</option>
|
||||||
|
<option value="768">768</option>
|
||||||
|
<option value="832">832</option>
|
||||||
|
<option value="896">896</option>
|
||||||
|
<option value="960">960</option>
|
||||||
|
<option value="1024">1024</option>
|
||||||
|
</select>
|
||||||
|
<label title="Set to multiple of 64" for="height">Height:</label>
|
||||||
|
<select id="height" name="height" value="512">
|
||||||
|
<option value="64">64</option>
|
||||||
|
<option value="128">128</option>
|
||||||
|
<option value="192">192</option>
|
||||||
|
<option value="256">256</option>
|
||||||
|
<option value="320">320</option>
|
||||||
|
<option value="384">384</option>
|
||||||
|
<option value="448">448</option>
|
||||||
|
<option value="512" selected>512</option>
|
||||||
|
<option value="576">576</option>
|
||||||
|
<option value="640">640</option>
|
||||||
|
<option value="704">704</option>
|
||||||
|
<option value="768">768</option>
|
||||||
|
<option value="832">832</option>
|
||||||
|
<option value="896">896</option>
|
||||||
|
<option value="960">960</option>
|
||||||
|
<option value="1024">1024</option>
|
||||||
|
</select>
|
||||||
|
<label title="Set to 0 for random seed" for="seed">Seed:</label>
|
||||||
|
<input value="0" type="number" id="seed" name="seed">
|
||||||
|
<button type="button" id="reset-seed">↺</button>
|
||||||
|
<input type="checkbox" name="progress_images" id="progress_images">
|
||||||
|
<label for="progress_images">Display in-progress images (slower)</label>
|
||||||
|
<div>
|
||||||
|
<label title="If > 0, adds thresholding to restrict values for k-diffusion samplers (0 disables)" for="threshold">Threshold:</label>
|
||||||
|
<input value="0" type="number" id="threshold" name="threshold" step="0.1" min="0">
|
||||||
|
<label title="Perlin: optional 0-1 value adds a percentage of perlin noise to the initial noise" for="perlin">Perlin:</label>
|
||||||
|
<input value="0" type="number" id="perlin" name="perlin" step="0.01" min="0" max="1">
|
||||||
|
<button type="button" id="reset-all">Reset to Defaults</button>
|
||||||
|
</div>
|
||||||
|
<div id="variations">
|
||||||
|
<label
|
||||||
|
title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different."
|
||||||
|
for="variation_amount">Variation amount (0 to disable):</label>
|
||||||
|
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
|
||||||
|
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..."
|
||||||
|
for="with_variations">With variations (seed:weight,seed:weight,...):</label>
|
||||||
|
<input value="" type="text" id="with_variations" name="with_variations">
|
||||||
|
</div>
|
||||||
|
</fieldset>
|
||||||
|
<fieldset id="initimg">
|
||||||
|
<legend>
|
||||||
|
<input type="checkbox" name="enable_init_image" id="enable_init_image" checked>
|
||||||
|
<label for="enable_init_image">Enable init image</label>
|
||||||
|
</legend>
|
||||||
|
<div>
|
||||||
|
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
|
||||||
|
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
|
||||||
|
<button type="button" id="remove-image">Remove Image</button>
|
||||||
|
</div>
|
||||||
|
<fieldset id="img2img">
|
||||||
|
<legend>
|
||||||
|
<input type="checkbox" name="enable_img2img" id="enable_img2img" checked>
|
||||||
|
<label for="enable_img2img">Enable Img2Img</label>
|
||||||
|
</legend>
|
||||||
|
<label for="strength">Img2Img Strength:</label>
|
||||||
|
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
|
||||||
|
<input type="checkbox" id="fit" name="fit" checked>
|
||||||
|
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
|
||||||
|
</fieldset>
|
||||||
|
</fieldset>
|
||||||
|
<div id="postprocess">
|
||||||
|
<fieldset id="gfpgan">
|
||||||
|
<legend>
|
||||||
|
<input type="checkbox" name="enable_gfpgan" id="enable_gfpgan">
|
||||||
|
<label for="enable_gfpgan">Enable gfpgan</label>
|
||||||
|
</legend>
|
||||||
|
<label title="Strength of the gfpgan (face fixing) algorithm." for="facetool_strength">GPFGAN Strength:</label>
|
||||||
|
<input value="0.8" min="0" max="1" type="number" id="facetool_strength" name="facetool_strength" step="0.05">
|
||||||
|
</fieldset>
|
||||||
|
<fieldset id="upscale">
|
||||||
|
<legend>
|
||||||
|
<input type="checkbox" name="enable_upscale" id="enable_upscale">
|
||||||
|
<label for="enable_upscale">Enable Upscaling</label>
|
||||||
|
</legend>
|
||||||
|
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level:</label>
|
||||||
|
<select id="upscale_level" name="upscale_level" value="">
|
||||||
|
<option value="" selected>None</option>
|
||||||
|
<option value="2">2x</option>
|
||||||
|
<option value="4">4x</option>
|
||||||
|
</select>
|
||||||
|
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
|
||||||
|
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
|
||||||
|
</fieldset>
|
||||||
|
</div>
|
||||||
|
<input type="submit" id="submit" value="Generate">
|
||||||
|
</form>
|
||||||
|
<br>
|
||||||
|
<section id="progress-section">
|
||||||
|
<div id="progress-container">
|
||||||
|
<progress id="progress-bar" value="0" max="1"></progress>
|
||||||
|
<span id="cancel-button" title="Cancel">✖</span>
|
||||||
|
<br>
|
||||||
|
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
|
||||||
|
<div id="scaling-inprocess-message">
|
||||||
|
<i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<div id="results">
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
396
static/dream_web/index.js
Normal file
396
static/dream_web/index.js
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
const socket = io();
|
||||||
|
|
||||||
|
var priorResultsLoadState = {
|
||||||
|
page: 0,
|
||||||
|
pages: 1,
|
||||||
|
per_page: 10,
|
||||||
|
total: 20,
|
||||||
|
offset: 0, // number of items generated since last load
|
||||||
|
loading: false,
|
||||||
|
initialized: false
|
||||||
|
};
|
||||||
|
|
||||||
|
function loadPriorResults() {
|
||||||
|
// Fix next page by offset
|
||||||
|
let offsetPages = priorResultsLoadState.offset / priorResultsLoadState.per_page;
|
||||||
|
priorResultsLoadState.page += offsetPages;
|
||||||
|
priorResultsLoadState.pages += offsetPages;
|
||||||
|
priorResultsLoadState.total += priorResultsLoadState.offset;
|
||||||
|
priorResultsLoadState.offset = 0;
|
||||||
|
|
||||||
|
if (priorResultsLoadState.loading) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (priorResultsLoadState.page >= priorResultsLoadState.pages) {
|
||||||
|
return; // Nothing more to load
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load
|
||||||
|
priorResultsLoadState.loading = true
|
||||||
|
let url = new URL('/api/images', document.baseURI);
|
||||||
|
url.searchParams.append('page', priorResultsLoadState.initialized ? priorResultsLoadState.page + 1 : priorResultsLoadState.page);
|
||||||
|
url.searchParams.append('per_page', priorResultsLoadState.per_page);
|
||||||
|
fetch(url.href, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: new Headers({'content-type': 'application/json'})
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
priorResultsLoadState.page = data.page;
|
||||||
|
priorResultsLoadState.pages = data.pages;
|
||||||
|
priorResultsLoadState.per_page = data.per_page;
|
||||||
|
priorResultsLoadState.total = data.total;
|
||||||
|
|
||||||
|
data.items.forEach(function(dreamId, index) {
|
||||||
|
let src = 'api/images/' + dreamId;
|
||||||
|
fetch('/api/images/' + dreamId + '/metadata', {
|
||||||
|
method: 'GET',
|
||||||
|
headers: new Headers({'content-type': 'application/json'})
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(metadata => {
|
||||||
|
let seed = metadata.seed || 0; // TODO: Parse old metadata
|
||||||
|
appendOutput(src, seed, metadata, true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Load until page is full
|
||||||
|
if (!priorResultsLoadState.initialized) {
|
||||||
|
if (document.body.scrollHeight <= window.innerHeight) {
|
||||||
|
loadPriorResults();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
priorResultsLoadState.loading = false;
|
||||||
|
priorResultsLoadState.initialized = true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function resetForm() {
|
||||||
|
var form = document.getElementById('generate-form');
|
||||||
|
form.querySelector('fieldset').removeAttribute('disabled');
|
||||||
|
}
|
||||||
|
|
||||||
|
function initProgress(totalSteps, showProgressImages) {
|
||||||
|
// TODO: Progress could theoretically come from multiple jobs at the same time (in the future)
|
||||||
|
let progressSectionEle = document.querySelector('#progress-section');
|
||||||
|
progressSectionEle.style.display = 'initial';
|
||||||
|
let progressEle = document.querySelector('#progress-bar');
|
||||||
|
progressEle.setAttribute('max', totalSteps);
|
||||||
|
|
||||||
|
let progressImageEle = document.querySelector('#progress-image');
|
||||||
|
progressImageEle.src = BLANK_IMAGE_URL;
|
||||||
|
progressImageEle.style.display = showProgressImages ? 'initial': 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
function setProgress(step, totalSteps, src) {
|
||||||
|
let progressEle = document.querySelector('#progress-bar');
|
||||||
|
progressEle.setAttribute('value', step);
|
||||||
|
|
||||||
|
if (src) {
|
||||||
|
let progressImageEle = document.querySelector('#progress-image');
|
||||||
|
progressImageEle.src = src;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function resetProgress(hide = true) {
|
||||||
|
if (hide) {
|
||||||
|
let progressSectionEle = document.querySelector('#progress-section');
|
||||||
|
progressSectionEle.style.display = 'none';
|
||||||
|
}
|
||||||
|
let progressEle = document.querySelector('#progress-bar');
|
||||||
|
progressEle.setAttribute('value', 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
function toBase64(file) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const r = new FileReader();
|
||||||
|
r.readAsDataURL(file);
|
||||||
|
r.onload = () => resolve(r.result);
|
||||||
|
r.onerror = (error) => reject(error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function ondragdream(event) {
|
||||||
|
let dream = event.target.dataset.dream;
|
||||||
|
event.dataTransfer.setData("dream", dream);
|
||||||
|
}
|
||||||
|
|
||||||
|
function seedClick(event) {
|
||||||
|
// Get element
|
||||||
|
var image = event.target.closest('figure').querySelector('img');
|
||||||
|
var dream = JSON.parse(decodeURIComponent(image.dataset.dream));
|
||||||
|
|
||||||
|
let form = document.querySelector("#generate-form");
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
if (k == 'initimg') { continue; }
|
||||||
|
let formElem = form.querySelector(`*[name=${k}]`);
|
||||||
|
formElem.value = dream[k] !== undefined ? dream[k] : formElem.defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
document.querySelector("#seed").value = dream.seed;
|
||||||
|
document.querySelector('#iterations').value = 1; // Reset to 1 iteration since we clicked a single image (not a full job)
|
||||||
|
|
||||||
|
// NOTE: leaving this manual for the user for now - it was very confusing with this behavior
|
||||||
|
// document.querySelector("#with_variations").value = variations || '';
|
||||||
|
// if (document.querySelector("#variation_amount").value <= 0) {
|
||||||
|
// document.querySelector("#variation_amount").value = 0.2;
|
||||||
|
// }
|
||||||
|
|
||||||
|
saveFields(document.querySelector("#generate-form"));
|
||||||
|
}
|
||||||
|
|
||||||
|
function appendOutput(src, seed, config, toEnd=false) {
|
||||||
|
let outputNode = document.createElement("figure");
|
||||||
|
let altText = seed.toString() + " | " + config.prompt;
|
||||||
|
|
||||||
|
// img needs width and height for lazy loading to work
|
||||||
|
// TODO: store the full config in a data attribute on the image?
|
||||||
|
const figureContents = `
|
||||||
|
<a href="${src}" target="_blank">
|
||||||
|
<img src="${src}"
|
||||||
|
alt="${altText}"
|
||||||
|
title="${altText}"
|
||||||
|
loading="lazy"
|
||||||
|
width="256"
|
||||||
|
height="256"
|
||||||
|
draggable="true"
|
||||||
|
ondragstart="ondragdream(event, this)"
|
||||||
|
data-dream="${encodeURIComponent(JSON.stringify(config))}"
|
||||||
|
data-dreamId="${encodeURIComponent(config.dreamId)}">
|
||||||
|
</a>
|
||||||
|
<figcaption onclick="seedClick(event, this)">${seed}</figcaption>
|
||||||
|
`;
|
||||||
|
|
||||||
|
outputNode.innerHTML = figureContents;
|
||||||
|
|
||||||
|
if (toEnd) {
|
||||||
|
document.querySelector("#results").append(outputNode);
|
||||||
|
} else {
|
||||||
|
document.querySelector("#results").prepend(outputNode);
|
||||||
|
}
|
||||||
|
document.querySelector("#no-results-message")?.remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveFields(form) {
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
if (typeof v !== 'object') { // Don't save 'file' type
|
||||||
|
localStorage.setItem(k, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function loadFields(form) {
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
const item = localStorage.getItem(k);
|
||||||
|
if (item != null) {
|
||||||
|
form.querySelector(`*[name=${k}]`).value = item;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function clearFields(form) {
|
||||||
|
localStorage.clear();
|
||||||
|
let prompt = form.prompt.value;
|
||||||
|
form.reset();
|
||||||
|
form.prompt.value = prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
|
||||||
|
async function generateSubmit(form) {
|
||||||
|
// Convert file data to base64
|
||||||
|
// TODO: Should probably uplaod files with formdata or something, and store them in the backend?
|
||||||
|
let formData = Object.fromEntries(new FormData(form));
|
||||||
|
if (!formData.enable_generate && !formData.enable_init_image) {
|
||||||
|
gen_label = document.querySelector("label[for=enable_generate]").innerHTML;
|
||||||
|
initimg_label = document.querySelector("label[for=enable_init_image]").innerHTML;
|
||||||
|
alert(`Error: one of "${gen_label}" or "${initimg_label}" must be set`);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
formData.initimg_name = formData.initimg.name
|
||||||
|
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||||
|
|
||||||
|
// Evaluate all checkboxes
|
||||||
|
let checkboxes = form.querySelectorAll('input[type=checkbox]');
|
||||||
|
checkboxes.forEach(function (checkbox) {
|
||||||
|
if (checkbox.checked) {
|
||||||
|
formData[checkbox.name] = 'true';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let strength = formData.strength;
|
||||||
|
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
|
||||||
|
let showProgressImages = formData.progress_images;
|
||||||
|
|
||||||
|
// Set enabling flags
|
||||||
|
|
||||||
|
|
||||||
|
// Initialize the progress bar
|
||||||
|
initProgress(totalSteps, showProgressImages);
|
||||||
|
|
||||||
|
// POST, use response to listen for events
|
||||||
|
fetch(form.action, {
|
||||||
|
method: form.method,
|
||||||
|
headers: new Headers({'content-type': 'application/json'}),
|
||||||
|
body: JSON.stringify(formData),
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
var jobId = data.jobId;
|
||||||
|
socket.emit('join_room', { 'room': jobId });
|
||||||
|
});
|
||||||
|
|
||||||
|
form.querySelector('fieldset').setAttribute('disabled','');
|
||||||
|
}
|
||||||
|
|
||||||
|
function fieldSetEnableChecked(event) {
|
||||||
|
cb = event.target;
|
||||||
|
fields = cb.closest('fieldset');
|
||||||
|
fields.disabled = !cb.checked;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Socket listeners
|
||||||
|
socket.on('job_started', (data) => {})
|
||||||
|
|
||||||
|
socket.on('dream_result', (data) => {
|
||||||
|
var jobId = data.jobId;
|
||||||
|
var dreamId = data.dreamId;
|
||||||
|
var dreamRequest = data.dreamRequest;
|
||||||
|
var src = 'api/images/' + dreamId;
|
||||||
|
|
||||||
|
priorResultsLoadState.offset += 1;
|
||||||
|
appendOutput(src, dreamRequest.seed, dreamRequest);
|
||||||
|
|
||||||
|
resetProgress(false);
|
||||||
|
})
|
||||||
|
|
||||||
|
socket.on('dream_progress', (data) => {
|
||||||
|
// TODO: it'd be nice if we could get a seed reported here, but the generator would need to be updated
|
||||||
|
var step = data.step;
|
||||||
|
var totalSteps = data.totalSteps;
|
||||||
|
var jobId = data.jobId;
|
||||||
|
var dreamId = data.dreamId;
|
||||||
|
|
||||||
|
var progressType = data.progressType
|
||||||
|
if (progressType === 'GENERATION') {
|
||||||
|
var src = data.hasProgressImage ?
|
||||||
|
'api/intermediates/' + dreamId + '/' + step
|
||||||
|
: null;
|
||||||
|
setProgress(step, totalSteps, src);
|
||||||
|
} else if (progressType === 'UPSCALING_STARTED') {
|
||||||
|
// step and totalSteps are used for upscale count on this message
|
||||||
|
document.getElementById("processing_cnt").textContent = step;
|
||||||
|
document.getElementById("processing_total").textContent = totalSteps;
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "block";
|
||||||
|
} else if (progressType == 'UPSCALING_DONE') {
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "none";
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
socket.on('job_canceled', (data) => {
|
||||||
|
resetForm();
|
||||||
|
resetProgress();
|
||||||
|
})
|
||||||
|
|
||||||
|
socket.on('job_done', (data) => {
|
||||||
|
jobId = data.jobId
|
||||||
|
socket.emit('leave_room', { 'room': jobId });
|
||||||
|
|
||||||
|
resetForm();
|
||||||
|
resetProgress();
|
||||||
|
})
|
||||||
|
|
||||||
|
window.onload = async () => {
|
||||||
|
document.querySelector("#prompt").addEventListener("keydown", (e) => {
|
||||||
|
if (e.key === "Enter" && !e.shiftKey) {
|
||||||
|
const form = e.target.form;
|
||||||
|
generateSubmit(form);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
document.querySelector("#generate-form").addEventListener('submit', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
const form = e.target;
|
||||||
|
|
||||||
|
generateSubmit(form);
|
||||||
|
});
|
||||||
|
document.querySelector("#generate-form").addEventListener('change', (e) => {
|
||||||
|
saveFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#reset-seed").addEventListener('click', (e) => {
|
||||||
|
document.querySelector("#seed").value = 0;
|
||||||
|
saveFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#reset-all").addEventListener('click', (e) => {
|
||||||
|
clearFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#remove-image").addEventListener('click', (e) => {
|
||||||
|
initimg.value=null;
|
||||||
|
});
|
||||||
|
loadFields(document.querySelector("#generate-form"));
|
||||||
|
|
||||||
|
document.querySelector('#cancel-button').addEventListener('click', () => {
|
||||||
|
fetch('/api/cancel').catch(e => {
|
||||||
|
console.error(e);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
document.documentElement.addEventListener('keydown', (e) => {
|
||||||
|
if (e.key === "Escape")
|
||||||
|
fetch('/api/cancel').catch(err => {
|
||||||
|
console.error(err);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!config.gfpgan_model_exists) {
|
||||||
|
document.querySelector("#gfpgan").style.display = 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener("scroll", () => {
|
||||||
|
if ((window.innerHeight + window.pageYOffset) >= document.body.offsetHeight) {
|
||||||
|
loadPriorResults();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Enable/disable forms by checkboxes
|
||||||
|
document.querySelectorAll("legend > input[type=checkbox]").forEach(function(cb) {
|
||||||
|
cb.addEventListener('change', fieldSetEnableChecked);
|
||||||
|
fieldSetEnableChecked({ target: cb})
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
// Load some of the previous results
|
||||||
|
loadPriorResults();
|
||||||
|
|
||||||
|
// Image drop/upload WIP
|
||||||
|
/*
|
||||||
|
let drop = document.getElementById('dropper');
|
||||||
|
function ondrop(event) {
|
||||||
|
let dreamData = event.dataTransfer.getData('dream');
|
||||||
|
if (dreamData) {
|
||||||
|
var dream = JSON.parse(decodeURIComponent(dreamData));
|
||||||
|
alert(dream.dreamId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
function ondragenter(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
};
|
||||||
|
|
||||||
|
function ondragover(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
};
|
||||||
|
|
||||||
|
function ondragleave(event) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
drop.addEventListener('drop', ondrop);
|
||||||
|
drop.addEventListener('dragenter', ondragenter);
|
||||||
|
drop.addEventListener('dragover', ondragover);
|
||||||
|
drop.addEventListener('dragleave', ondragleave);
|
||||||
|
*/
|
||||||
|
};
|
BIN
static/legacy_web/favicon.ico
Normal file
BIN
static/legacy_web/favicon.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.1 KiB |
152
static/legacy_web/index.css
Normal file
152
static/legacy_web/index.css
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
* {
|
||||||
|
font-family: 'Arial';
|
||||||
|
font-size: 100%;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-size: 1em;
|
||||||
|
}
|
||||||
|
textarea {
|
||||||
|
font-size: 0.95em;
|
||||||
|
}
|
||||||
|
header, form, #progress-section {
|
||||||
|
margin-left: auto;
|
||||||
|
margin-right: auto;
|
||||||
|
max-width: 1024px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
fieldset {
|
||||||
|
border: none;
|
||||||
|
line-height: 2.2em;
|
||||||
|
}
|
||||||
|
select, input {
|
||||||
|
margin-right: 10px;
|
||||||
|
padding: 2px;
|
||||||
|
}
|
||||||
|
input[type=submit] {
|
||||||
|
background-color: #666;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
input[type=checkbox] {
|
||||||
|
margin-right: 0px;
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
input#seed {
|
||||||
|
margin-right: 0px;
|
||||||
|
}
|
||||||
|
div {
|
||||||
|
padding: 10px 10px 10px 10px;
|
||||||
|
}
|
||||||
|
header {
|
||||||
|
margin-bottom: 16px;
|
||||||
|
}
|
||||||
|
header h1 {
|
||||||
|
margin-bottom: 0;
|
||||||
|
font-size: 2em;
|
||||||
|
}
|
||||||
|
#search-box {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
#scaling-inprocess-message {
|
||||||
|
font-weight: bold;
|
||||||
|
font-style: italic;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
#prompt {
|
||||||
|
flex-grow: 1;
|
||||||
|
padding: 5px 10px 5px 10px;
|
||||||
|
border: 1px solid #999;
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
#submit {
|
||||||
|
padding: 5px 10px 5px 10px;
|
||||||
|
border: 1px solid #999;
|
||||||
|
}
|
||||||
|
#reset-all, #remove-image {
|
||||||
|
margin-top: 12px;
|
||||||
|
font-size: 0.8em;
|
||||||
|
background-color: pink;
|
||||||
|
border: 1px solid #999;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
#results {
|
||||||
|
text-align: center;
|
||||||
|
margin: auto;
|
||||||
|
padding-top: 10px;
|
||||||
|
}
|
||||||
|
#results figure {
|
||||||
|
display: inline-block;
|
||||||
|
margin: 10px;
|
||||||
|
}
|
||||||
|
#results figcaption {
|
||||||
|
font-size: 0.8em;
|
||||||
|
padding: 3px;
|
||||||
|
color: #888;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
#results img {
|
||||||
|
border-radius: 5px;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
#fieldset-config {
|
||||||
|
line-height:2em;
|
||||||
|
background-color: #F0F0F0;
|
||||||
|
}
|
||||||
|
input[type="number"] {
|
||||||
|
width: 60px;
|
||||||
|
}
|
||||||
|
#seed {
|
||||||
|
width: 150px;
|
||||||
|
}
|
||||||
|
button#reset-seed {
|
||||||
|
font-size: 1.7em;
|
||||||
|
background: #efefef;
|
||||||
|
border: 1px solid #999;
|
||||||
|
border-radius: 4px;
|
||||||
|
line-height: 0.8;
|
||||||
|
margin: 0 10px 0 0;
|
||||||
|
padding: 0 5px 3px;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
#progress-section {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
#progress-image {
|
||||||
|
width: 30vh;
|
||||||
|
height: 30vh;
|
||||||
|
}
|
||||||
|
#cancel-button {
|
||||||
|
cursor: pointer;
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
#basic-parameters {
|
||||||
|
background-color: #EEEEEE;
|
||||||
|
}
|
||||||
|
#txt2img {
|
||||||
|
background-color: #DCDCDC;
|
||||||
|
}
|
||||||
|
#variations {
|
||||||
|
background-color: #EEEEEE;
|
||||||
|
}
|
||||||
|
#img2img {
|
||||||
|
background-color: #DCDCDC;
|
||||||
|
}
|
||||||
|
#gfpgan {
|
||||||
|
background-color: #EEEEEE;
|
||||||
|
}
|
||||||
|
#progress-section {
|
||||||
|
background-color: #F5F5F5;
|
||||||
|
}
|
||||||
|
.section-header {
|
||||||
|
text-align: left;
|
||||||
|
font-weight: bold;
|
||||||
|
padding: 0 0 0 0;
|
||||||
|
}
|
||||||
|
#no-results-message:not(:only-child) {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
137
static/legacy_web/index.html
Normal file
137
static/legacy_web/index.html
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<title>Stable Diffusion Dream Server</title>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<link rel="icon" type="image/x-icon" href="static/legacy_web/favicon.ico" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<link rel="stylesheet" href="static/legacy_web/index.css">
|
||||||
|
<script src="config.js"></script>
|
||||||
|
<script src="static/legacy_web/index.js"></script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<header>
|
||||||
|
<h1>Stable Diffusion Dream Server</h1>
|
||||||
|
<div id="about">
|
||||||
|
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a>
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<main>
|
||||||
|
<form id="generate-form" method="post" action="#">
|
||||||
|
<fieldset id="txt2img">
|
||||||
|
<div id="search-box">
|
||||||
|
<textarea rows="3" id="prompt" name="prompt"></textarea>
|
||||||
|
<input type="submit" id="submit" value="Generate">
|
||||||
|
</div>
|
||||||
|
</fieldset>
|
||||||
|
<fieldset id="fieldset-config">
|
||||||
|
<div class="section-header">Basic options</div>
|
||||||
|
<label for="iterations">Images to generate:</label>
|
||||||
|
<input value="1" type="number" id="iterations" name="iterations" size="4">
|
||||||
|
<label for="steps">Steps:</label>
|
||||||
|
<input value="50" type="number" id="steps" name="steps">
|
||||||
|
<label for="cfg_scale">Cfg Scale:</label>
|
||||||
|
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
|
||||||
|
<label for="sampler_name">Sampler:</label>
|
||||||
|
<select id="sampler_name" name="sampler_name" value="k_lms">
|
||||||
|
<option value="ddim">DDIM</option>
|
||||||
|
<option value="plms">PLMS</option>
|
||||||
|
<option value="k_lms" selected>KLMS</option>
|
||||||
|
<option value="k_dpm_2">KDPM_2</option>
|
||||||
|
<option value="k_dpm_2_a">KDPM_2A</option>
|
||||||
|
<option value="k_dpmpp_2">KDPMPP_2</option>
|
||||||
|
<option value="k_dpmpp_2_a">KDPMPP_2A</option>
|
||||||
|
<option value="k_euler">KEULER</option>
|
||||||
|
<option value="k_euler_a">KEULER_A</option>
|
||||||
|
<option value="k_heun">KHEUN</option>
|
||||||
|
</select>
|
||||||
|
<input type="checkbox" name="seamless" id="seamless">
|
||||||
|
<label for="seamless">Seamless circular tiling</label>
|
||||||
|
<br>
|
||||||
|
<label title="Set to multiple of 64" for="width">Width:</label>
|
||||||
|
<select id="width" name="width" value="512">
|
||||||
|
<option value="64">64</option> <option value="128">128</option>
|
||||||
|
<option value="192">192</option> <option value="256">256</option>
|
||||||
|
<option value="320">320</option> <option value="384">384</option>
|
||||||
|
<option value="448">448</option> <option value="512" selected>512</option>
|
||||||
|
<option value="576">576</option> <option value="640">640</option>
|
||||||
|
<option value="704">704</option> <option value="768">768</option>
|
||||||
|
<option value="832">832</option> <option value="896">896</option>
|
||||||
|
<option value="960">960</option> <option value="1024">1024</option>
|
||||||
|
</select>
|
||||||
|
<label title="Set to multiple of 64" for="height">Height:</label>
|
||||||
|
<select id="height" name="height" value="512">
|
||||||
|
<option value="64">64</option> <option value="128">128</option>
|
||||||
|
<option value="192">192</option> <option value="256">256</option>
|
||||||
|
<option value="320">320</option> <option value="384">384</option>
|
||||||
|
<option value="448">448</option> <option value="512" selected>512</option>
|
||||||
|
<option value="576">576</option> <option value="640">640</option>
|
||||||
|
<option value="704">704</option> <option value="768">768</option>
|
||||||
|
<option value="832">832</option> <option value="896">896</option>
|
||||||
|
<option value="960">960</option> <option value="1024">1024</option>
|
||||||
|
</select>
|
||||||
|
<label title="Set to -1 for random seed" for="seed">Seed:</label>
|
||||||
|
<input value="-1" type="number" id="seed" name="seed">
|
||||||
|
<button type="button" id="reset-seed">↺</button>
|
||||||
|
<input type="checkbox" name="progress_images" id="progress_images">
|
||||||
|
<label for="progress_images">Display in-progress images (slower)</label>
|
||||||
|
<div>
|
||||||
|
<label title="If > 0, adds thresholding to restrict values for k-diffusion samplers (0 disables)" for="threshold">Threshold:</label>
|
||||||
|
<input value="0" type="number" id="threshold" name="threshold" step="0.1" min="0">
|
||||||
|
<label title="Perlin: optional 0-1 value adds a percentage of perlin noise to the initial noise" for="perlin">Perlin:</label>
|
||||||
|
<input value="0" type="number" id="perlin" name="perlin" step="0.01" min="0" max="1">
|
||||||
|
<button type="button" id="reset-all">Reset to Defaults</button>
|
||||||
|
</div>
|
||||||
|
<span id="variations">
|
||||||
|
<label title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different." for="variation_amount">Variation amount (0 to disable):</label>
|
||||||
|
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
|
||||||
|
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..." for="with_variations">With variations (seed:weight,seed:weight,...):</label>
|
||||||
|
<input value="" type="text" id="with_variations" name="with_variations">
|
||||||
|
</span>
|
||||||
|
</fieldset>
|
||||||
|
<fieldset id="img2img">
|
||||||
|
<div class="section-header">Image-to-image options</div>
|
||||||
|
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
|
||||||
|
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
|
||||||
|
<button type="button" id="remove-image">Remove Image</button>
|
||||||
|
<br>
|
||||||
|
<label for="strength">Img2Img Strength:</label>
|
||||||
|
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
|
||||||
|
<input type="checkbox" id="fit" name="fit" checked>
|
||||||
|
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height</label>
|
||||||
|
</fieldset>
|
||||||
|
<fieldset id="gfpgan">
|
||||||
|
<div class="section-header">Post-processing options</div>
|
||||||
|
<label title="Strength of the gfpgan (face fixing) algorithm." for="facetool_strength">GPFGAN Strength (0 to disable):</label>
|
||||||
|
<input value="0.0" min="0" max="1" type="number" id="facetool_strength" name="facetool_strength" step="0.1">
|
||||||
|
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label>
|
||||||
|
<select id="upscale_level" name="upscale_level" value="">
|
||||||
|
<option value="" selected>None</option>
|
||||||
|
<option value="2">2x</option>
|
||||||
|
<option value="4">4x</option>
|
||||||
|
</select>
|
||||||
|
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
|
||||||
|
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
|
||||||
|
</fieldset>
|
||||||
|
</form>
|
||||||
|
<br>
|
||||||
|
<section id="progress-section">
|
||||||
|
<div id="progress-container">
|
||||||
|
<progress id="progress-bar" value="0" max="1"></progress>
|
||||||
|
<span id="cancel-button" title="Cancel">✖</span>
|
||||||
|
<br>
|
||||||
|
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
|
||||||
|
<div id="scaling-inprocess-message">
|
||||||
|
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
|
||||||
|
</div>
|
||||||
|
</span>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<div id="results">
|
||||||
|
<div id="no-results-message">
|
||||||
|
<i><p>No results...</p></i>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
</html>
|
213
static/legacy_web/index.js
Normal file
213
static/legacy_web/index.js
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
function toBase64(file) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const r = new FileReader();
|
||||||
|
r.readAsDataURL(file);
|
||||||
|
r.onload = () => resolve(r.result);
|
||||||
|
r.onerror = (error) => reject(error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function appendOutput(src, seed, config) {
|
||||||
|
let outputNode = document.createElement("figure");
|
||||||
|
|
||||||
|
let variations = config.with_variations;
|
||||||
|
if (config.variation_amount > 0) {
|
||||||
|
variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount;
|
||||||
|
}
|
||||||
|
let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed;
|
||||||
|
let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
|
||||||
|
|
||||||
|
// img needs width and height for lazy loading to work
|
||||||
|
const figureContents = `
|
||||||
|
<a href="${src}" target="_blank">
|
||||||
|
<img src="${src}"
|
||||||
|
alt="${altText}"
|
||||||
|
title="${altText}"
|
||||||
|
loading="lazy"
|
||||||
|
width="256"
|
||||||
|
height="256">
|
||||||
|
</a>
|
||||||
|
<figcaption>${seed}</figcaption>
|
||||||
|
`;
|
||||||
|
|
||||||
|
outputNode.innerHTML = figureContents;
|
||||||
|
let figcaption = outputNode.querySelector('figcaption');
|
||||||
|
|
||||||
|
// Reload image config
|
||||||
|
figcaption.addEventListener('click', () => {
|
||||||
|
let form = document.querySelector("#generate-form");
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
if (k == 'initimg') { continue; }
|
||||||
|
form.querySelector(`*[name=${k}]`).value = config[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
document.querySelector("#seed").value = baseseed;
|
||||||
|
document.querySelector("#with_variations").value = variations || '';
|
||||||
|
if (document.querySelector("#variation_amount").value <= 0) {
|
||||||
|
document.querySelector("#variation_amount").value = 0.2;
|
||||||
|
}
|
||||||
|
|
||||||
|
saveFields(document.querySelector("#generate-form"));
|
||||||
|
});
|
||||||
|
|
||||||
|
document.querySelector("#results").prepend(outputNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveFields(form) {
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
if (typeof v !== 'object') { // Don't save 'file' type
|
||||||
|
localStorage.setItem(k, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function loadFields(form) {
|
||||||
|
for (const [k, v] of new FormData(form)) {
|
||||||
|
const item = localStorage.getItem(k);
|
||||||
|
if (item != null) {
|
||||||
|
form.querySelector(`*[name=${k}]`).value = item;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function clearFields(form) {
|
||||||
|
localStorage.clear();
|
||||||
|
let prompt = form.prompt.value;
|
||||||
|
form.reset();
|
||||||
|
form.prompt.value = prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
|
||||||
|
async function generateSubmit(form) {
|
||||||
|
const prompt = document.querySelector("#prompt").value;
|
||||||
|
|
||||||
|
// Convert file data to base64
|
||||||
|
let formData = Object.fromEntries(new FormData(form));
|
||||||
|
formData.initimg_name = formData.initimg.name
|
||||||
|
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||||
|
|
||||||
|
let strength = formData.strength;
|
||||||
|
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
|
||||||
|
|
||||||
|
let progressSectionEle = document.querySelector('#progress-section');
|
||||||
|
progressSectionEle.style.display = 'initial';
|
||||||
|
let progressEle = document.querySelector('#progress-bar');
|
||||||
|
progressEle.setAttribute('max', totalSteps);
|
||||||
|
let progressImageEle = document.querySelector('#progress-image');
|
||||||
|
progressImageEle.src = BLANK_IMAGE_URL;
|
||||||
|
|
||||||
|
progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none';
|
||||||
|
|
||||||
|
// Post as JSON, using Fetch streaming to get results
|
||||||
|
fetch(form.action, {
|
||||||
|
method: form.method,
|
||||||
|
body: JSON.stringify(formData),
|
||||||
|
}).then(async (response) => {
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
|
||||||
|
let noOutputs = true;
|
||||||
|
while (true) {
|
||||||
|
let {value, done} = await reader.read();
|
||||||
|
value = new TextDecoder().decode(value);
|
||||||
|
if (done) {
|
||||||
|
progressSectionEle.style.display = 'none';
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||||
|
const data = JSON.parse(event);
|
||||||
|
|
||||||
|
if (data.event === 'result') {
|
||||||
|
noOutputs = false;
|
||||||
|
appendOutput(data.url, data.seed, data.config);
|
||||||
|
progressEle.setAttribute('value', 0);
|
||||||
|
progressEle.setAttribute('max', totalSteps);
|
||||||
|
} else if (data.event === 'upscaling-started') {
|
||||||
|
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "block";
|
||||||
|
} else if (data.event === 'upscaling-done') {
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "none";
|
||||||
|
} else if (data.event === 'step') {
|
||||||
|
progressEle.setAttribute('value', data.step);
|
||||||
|
if (data.url) {
|
||||||
|
progressImageEle.src = data.url;
|
||||||
|
}
|
||||||
|
} else if (data.event === 'canceled') {
|
||||||
|
// avoid alerting as if this were an error case
|
||||||
|
noOutputs = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable form, remove no-results-message
|
||||||
|
form.querySelector('fieldset').removeAttribute('disabled');
|
||||||
|
document.querySelector("#prompt").value = prompt;
|
||||||
|
document.querySelector('progress').setAttribute('value', '0');
|
||||||
|
|
||||||
|
if (noOutputs) {
|
||||||
|
alert("Error occurred while generating.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Disable form while generating
|
||||||
|
form.querySelector('fieldset').setAttribute('disabled','');
|
||||||
|
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchRunLog() {
|
||||||
|
try {
|
||||||
|
let response = await fetch('/run_log.json')
|
||||||
|
const data = await response.json();
|
||||||
|
for(let item of data.run_log) {
|
||||||
|
appendOutput(item.url, item.seed, item);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.onload = async () => {
|
||||||
|
document.querySelector("#prompt").addEventListener("keydown", (e) => {
|
||||||
|
if (e.key === "Enter" && !e.shiftKey) {
|
||||||
|
const form = e.target.form;
|
||||||
|
generateSubmit(form);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
document.querySelector("#generate-form").addEventListener('submit', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
const form = e.target;
|
||||||
|
|
||||||
|
generateSubmit(form);
|
||||||
|
});
|
||||||
|
document.querySelector("#generate-form").addEventListener('change', (e) => {
|
||||||
|
saveFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#reset-seed").addEventListener('click', (e) => {
|
||||||
|
document.querySelector("#seed").value = -1;
|
||||||
|
saveFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#reset-all").addEventListener('click', (e) => {
|
||||||
|
clearFields(e.target.form);
|
||||||
|
});
|
||||||
|
document.querySelector("#remove-image").addEventListener('click', (e) => {
|
||||||
|
initimg.value=null;
|
||||||
|
});
|
||||||
|
loadFields(document.querySelector("#generate-form"));
|
||||||
|
|
||||||
|
document.querySelector('#cancel-button').addEventListener('click', () => {
|
||||||
|
fetch('/cancel').catch(e => {
|
||||||
|
console.error(e);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
document.documentElement.addEventListener('keydown', (e) => {
|
||||||
|
if (e.key === "Escape")
|
||||||
|
fetch('/cancel').catch(err => {
|
||||||
|
console.error(err);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!config.gfpgan_model_exists) {
|
||||||
|
document.querySelector("#gfpgan").style.display = 'none';
|
||||||
|
}
|
||||||
|
await fetchRunLog()
|
||||||
|
};
|
@ -21,12 +21,13 @@ def simple_graph():
|
|||||||
def mock_services():
|
def mock_services():
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
generate = None,
|
model_manager = None,
|
||||||
events = None,
|
events = None,
|
||||||
images = None,
|
images = None,
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor()
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||||
|
@ -21,12 +21,13 @@ def simple_graph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
generate = None, # type: ignore
|
model_manager = None, # type: ignore
|
||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor()
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
Loading…
Reference in New Issue
Block a user