delete old 'server' package and the dependency_injector requirement (#2032)

fixes #1944
This commit is contained in:
Eugene Brodsky 2022-12-16 06:28:16 -05:00 committed by GitHub
parent ffa54f4a35
commit 7d09d9da49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 0 additions and 1021 deletions

View File

@ -30,7 +30,6 @@ dependencies:
- torchvision - torchvision
- transformers=4.21.3 - transformers=4.21.3
- pip: - pip:
- dependency_injector==4.40.0
- getpass_asterisk - getpass_asterisk
- omegaconf==2.1.1 - omegaconf==2.1.1
- picklescan - picklescan

View File

@ -10,7 +10,6 @@ dependencies:
- pip: - pip:
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/ - --extra-index-url https://download.pytorch.org/whl/rocm5.2/
- albumentations==0.4.3 - albumentations==0.4.3
- dependency_injector==4.40.0
- diffusers==0.6.0 - diffusers==0.6.0
- einops==0.3.0 - einops==0.3.0
- eventlet - eventlet

View File

@ -13,7 +13,6 @@ dependencies:
- cudatoolkit=11.6 - cudatoolkit=11.6
- pip: - pip:
- albumentations==0.4.3 - albumentations==0.4.3
- dependency_injector==4.40.0
- diffusers==0.6.0 - diffusers==0.6.0
- einops==0.3.0 - einops==0.3.0
- eventlet - eventlet

View File

@ -13,7 +13,6 @@ dependencies:
- cudatoolkit=11.6 - cudatoolkit=11.6
- pip: - pip:
- albumentations==0.4.3 - albumentations==0.4.3
- dependency_injector==4.40.0
- diffusers==0.6.0 - diffusers==0.6.0
- einops==0.3.0 - einops==0.3.0
- eventlet - eventlet

View File

@ -1,6 +1,5 @@
# pip will resolve the version which matches torch # pip will resolve the version which matches torch
albumentations albumentations
dependency_injector==4.40.0
diffusers==0.10.* diffusers==0.10.*
einops einops
eventlet eventlet

View File

View File

@ -1,152 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Application module."""
import argparse
import json
import os
import sys
from flask import Flask
from flask_cors import CORS
from flask_socketio import SocketIO
from omegaconf import OmegaConf
from dependency_injector.wiring import inject, Provide
from ldm.invoke.args import Args
from server import views
from server.containers import Container
from server.services import GeneratorService, SignalService
# The socketio_service is injected here (rather than created in run_app) to initialize it
@inject
def initialize_app(
app: Flask,
socketio: SocketIO = Provide[Container.socketio]
) -> SocketIO:
socketio.init_app(app)
return socketio
# The signal and generator services are injected to warm up the processing queues
# TODO: Initialize these a better way?
@inject
def initialize_generator(
signal_service: SignalService = Provide[Container.signal_service],
generator_service: GeneratorService = Provide[Container.generator_service]
):
pass
def run_app(config, host, port) -> Flask:
app = Flask(__name__, static_url_path='')
# Set up dependency injection container
container = Container()
container.config.from_dict(config)
container.wire(modules=[__name__])
app.container = container
# Set up CORS
CORS(app, resources={r'/api/*': {'origins': '*'}})
# Web Routes
app.add_url_rule('/', view_func=views.WebIndex.as_view('web_index', 'index.html'))
app.add_url_rule('/index.css', view_func=views.WebIndex.as_view('web_index_css', 'index.css'))
app.add_url_rule('/index.js', view_func=views.WebIndex.as_view('web_index_js', 'index.js'))
app.add_url_rule('/config.js', view_func=views.WebConfig.as_view('web_config'))
# API Routes
app.add_url_rule('/api/jobs', view_func=views.ApiJobs.as_view('api_jobs'))
app.add_url_rule('/api/cancel', view_func=views.ApiCancel.as_view('api_cancel'))
# TODO: Get storage root from config
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
app.add_url_rule('/api/images/<string:dreamId>/metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../'))
app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list'))
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
# Initialize
socketio = initialize_app(app)
initialize_generator()
print(">> Started Stable Diffusion api server!")
if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.")
# Run the app
socketio.run(app, host, port)
def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = Args()
opt = arg_parser.parse_args()
if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
# try:
# models = OmegaConf.load(opt.config)
# width = models[opt.model].width
# height = models[opt.model].height
# config = models[opt.model].config
# weights = models[opt.model].weights
# except (FileNotFoundError, IOError, KeyError) as e:
# print(f'{e}. Aborting.')
# sys.exit(-1)
#print('* Initializing, be patient...\n')
sys.path.append('.')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
appConfig = opt.__dict__
# appConfig = {
# "model": {
# "width": width,
# "height": height,
# "sampler_name": opt.sampler_name,
# "weights": weights,
# "precision": opt.precision,
# "config": config,
# "grid": opt.grid,
# "latent_diffusion_weights": opt.laion400m,
# "embedding_path": opt.embedding_path
# }
# }
# make sure the output directory exists
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
from pytorch_lightning import logging
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
print('\n* starting api server...')
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
# Start server
try:
run_app(appConfig, opt.host, opt.port)
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()

View File

@ -1,81 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Containers module."""
from dependency_injector import containers, providers
from flask_socketio import SocketIO
from ldm.generate import Generate
from server import services
class Container(containers.DeclarativeContainer):
wiring_config = containers.WiringConfiguration(packages=['server'])
config = providers.Configuration()
socketio = providers.ThreadSafeSingleton(
SocketIO,
app = None
)
# TODO: Add a model provider service that provides model(s) dynamically
model_singleton = providers.ThreadSafeSingleton(
Generate,
model = config.model,
sampler_name = config.sampler_name,
embedding_path = config.embedding_path,
precision = config.precision
# config = config.model.config,
# width = config.model.width,
# height = config.model.height,
# sampler_name = config.model.sampler_name,
# weights = config.model.weights,
# precision = config.model.precision,
# grid = config.model.grid,
# seamless = config.model.seamless,
# embedding_path = config.model.embedding_path,
# device_type = config.model.device_type
)
# TODO: get location from config
image_storage_service = providers.ThreadSafeSingleton(
services.ImageStorageService,
'./outputs/img-samples/'
)
# TODO: get location from config
image_intermediates_storage_service = providers.ThreadSafeSingleton(
services.ImageStorageService,
'./outputs/intermediates/'
)
signal_queue_service = providers.ThreadSafeSingleton(
services.SignalQueueService
)
signal_service = providers.ThreadSafeSingleton(
services.SignalService,
socketio = socketio,
queue = signal_queue_service
)
generation_queue_service = providers.ThreadSafeSingleton(
services.JobQueueService
)
# TODO: get locations from config
log_service = providers.ThreadSafeSingleton(
services.LogService,
'./outputs/img-samples/',
'dream_web_log.txt'
)
generator_service = providers.ThreadSafeSingleton(
services.GeneratorService,
model = model_singleton,
queue = generation_queue_service,
imageStorage = image_storage_service,
intermediateStorage = image_intermediates_storage_service,
log = log_service,
signal_service = signal_service
)

View File

@ -1,259 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from base64 import urlsafe_b64encode
import json
import string
from copy import deepcopy
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Union
from uuid import uuid4
class DreamBase():
# Id
id: str
# Initial Image
enable_init_image: bool
initimg: string = None
# Img2Img
enable_img2img: bool # TODO: support this better
strength: float = 0 # TODO: name this something related to img2img to make it clearer?
fit = None # Fit initial image dimensions
# Generation
enable_generate: bool
prompt: string = ""
seed: int = 0 # 0 is random
steps: int = 10
width: int = 512
height: int = 512
cfg_scale: float = 7.5
threshold: float = 0.0
perlin: float = 0.0
sampler_name: string = 'klms'
seamless: bool = False
hires_fix: bool = False
model: str = None # The model to use (currently unused)
embeddings = None # The embeddings to use (currently unused)
progress_images: bool = False
progress_latents: bool = False
# GFPGAN
enable_gfpgan: bool
facetool_strength: float = 0
# Upscale
enable_upscale: bool
upscale: None
upscale_level: int = None
upscale_strength: float = 0.75
# Embiggen
enable_embiggen: bool
embiggen: Union[None, List[float]] = None
embiggen_tiles: Union[None, List[int]] = None
# Metadata
time: int
def __init__(self):
self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii')
def parse_json(self, j, new_instance=False):
# Id
if 'id' in j and not new_instance:
self.id = j.get('id')
# Initial Image
self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image'))
if self.enable_init_image:
self.initimg = j.get('initimg')
# Img2Img
self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img'))
if self.enable_img2img:
self.strength = float(j.get('strength'))
self.fit = 'fit' in j
# Generation
self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate'))
if self.enable_generate:
self.prompt = j.get('prompt')
self.seed = int(j.get('seed'))
self.steps = int(j.get('steps'))
self.width = int(j.get('width'))
self.height = int(j.get('height'))
self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale'))
self.threshold = float(j.get('threshold'))
self.perlin = float(j.get('perlin'))
self.sampler_name = j.get('sampler') or j.get('sampler_name')
# model: str = None # The model to use (currently unused)
# embeddings = None # The embeddings to use (currently unused)
self.seamless = 'seamless' in j
self.hires_fix = 'hires_fix' in j
self.progress_images = 'progress_images' in j
self.progress_latents = 'progress_latents' in j
# GFPGAN
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))
if self.enable_gfpgan:
self.facetool_strength = float(j.get('facetool_strength'))
# Upscale
self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale'))
if self.enable_upscale:
self.upscale_level = j.get('upscale_level')
self.upscale_strength = j.get('upscale_strength')
self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)]
# Embiggen
self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen'))
if self.enable_embiggen:
self.embiggen = j.get('embiggen')
self.embiggen_tiles = j.get('embiggen_tiles')
# Metadata
self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp())
class DreamResult(DreamBase):
# Result
has_upscaled: False
has_gfpgan: False
# TODO: use something else for state tracking
images_generated: int = 0
images_upscaled: int = 0
def __init__(self):
super().__init__()
def clone_without_img(self):
copy = deepcopy(self)
copy.initimg = None
return copy
def to_json(self):
copy = deepcopy(self)
copy.initimg = None
j = json.dumps(copy.__dict__)
return j
@staticmethod
def from_json(j, newTime: bool = False):
d = DreamResult()
d.parse_json(j)
return d
# TODO: switch this to a pipelined request, with pluggable steps
# Will likely require generator code changes to accomplish
class JobRequest(DreamBase):
# Iteration
iterations: int = 1
variation_amount = None
with_variations = None
# Results
results: List[DreamResult] = []
def __init__(self):
super().__init__()
def newDreamResult(self) -> DreamResult:
result = DreamResult()
result.parse_json(self.__dict__, new_instance=True)
return result
@staticmethod
def from_json(j):
job = JobRequest()
job.parse_json(j)
# Metadata
job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp())
# Iteration
if job.enable_generate:
job.iterations = int(j.get('iterations'))
job.variation_amount = float(j.get('variation_amount'))
job.with_variations = j.get('with_variations')
return job
class ProgressType(Enum):
GENERATION = 1
UPSCALING_STARTED = 2
UPSCALING_DONE = 3
class Signal():
event: str
data = None
room: str = None
broadcast: bool = False
def __init__(self, event: str, data, room: str = None, broadcast: bool = False):
self.event = event
self.data = data
self.room = room
self.broadcast = broadcast
@staticmethod
def image_progress(jobId: str, dreamId: str, step: int, totalSteps: int, progressType: ProgressType = ProgressType.GENERATION, hasProgressImage: bool = False):
return Signal('dream_progress', {
'jobId': jobId,
'dreamId': dreamId,
'step': step,
'totalSteps': totalSteps,
'hasProgressImage': hasProgressImage,
'progressType': progressType.name
}, room=jobId, broadcast=True)
# TODO: use a result id or something? Like a sub-job
@staticmethod
def image_result(jobId: str, dreamId: str, dreamResult: DreamResult):
return Signal('dream_result', {
'jobId': jobId,
'dreamId': dreamId,
'dreamRequest': dreamResult.clone_without_img().__dict__
}, room=jobId, broadcast=True)
@staticmethod
def job_started(jobId: str):
return Signal('job_started', {
'jobId': jobId
}, room=jobId, broadcast=True)
@staticmethod
def job_done(jobId: str):
return Signal('job_done', {
'jobId': jobId
}, room=jobId, broadcast=True)
@staticmethod
def job_canceled(jobId: str):
return Signal('job_canceled', {
'jobId': jobId
}, room=jobId, broadcast=True)
class PaginatedItems():
items: List[Any]
page: int # Current Page
pages: int # Total number of pages
per_page: int # Number of items per page
total: int # Total number of items in result
def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int):
self.items = items
self.page = page
self.pages = pages
self.per_page = per_page
self.total = total
def to_json(self):
return json.dumps(self.__dict__)

View File

@ -1,392 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import ArgumentParser
import base64
from datetime import datetime, timezone
import glob
import json
import os
from pathlib import Path
from queue import Empty, Queue
import shlex
from threading import Thread
import time
from flask_socketio import SocketIO, join_room, leave_room
from ldm.invoke.args import Args
from ldm.invoke.generator import embiggen
from PIL import Image
from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.server import CanceledException
from ldm.generate import Generate
from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal
class JobQueueService:
__queue: Queue = Queue()
def push(self, dreamRequest: DreamResult):
self.__queue.put(dreamRequest)
def get(self, timeout: float = None) -> DreamResult:
return self.__queue.get(timeout= timeout)
class SignalQueueService:
__queue: Queue = Queue()
def push(self, signal: Signal):
self.__queue.put(signal)
def get(self) -> Signal:
return self.__queue.get(block=False)
class SignalService:
__socketio: SocketIO
__queue: SignalQueueService
def __init__(self, socketio: SocketIO, queue: SignalQueueService):
self.__socketio = socketio
self.__queue = queue
def on_join(data):
room = data['room']
join_room(room)
self.__socketio.emit("test", "something", room=room)
def on_leave(data):
room = data['room']
leave_room(room)
self.__socketio.on_event('join_room', on_join)
self.__socketio.on_event('leave_room', on_leave)
self.__socketio.start_background_task(self.__process)
def __process(self):
# preload the model
print('Started signal queue processor')
try:
while True:
try:
signal = self.__queue.get()
self.__socketio.emit(signal.event, signal.data, room=signal.room, broadcast=signal.broadcast)
except Empty:
pass
finally:
self.__socketio.sleep(0.001)
except KeyboardInterrupt:
print('Signal queue processor stopped')
def emit(self, signal: Signal):
self.__queue.push(signal)
# TODO: Name this better?
# TODO: Logging and signals should probably be event based (multiple listeners for an event)
class LogService:
__location: str
__logFile: str
def __init__(self, location:str, file:str):
self.__location = location
self.__logFile = file
def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
with open(os.path.join(self.__location, self.__logFile), "a") as log:
log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")
class ImageStorageService:
__location: str
__pngWriter: PngWriter
__legacyParser: ArgumentParser
def __init__(self, location):
self.__location = location
self.__pngWriter = PngWriter(self.__location)
self.__legacyParser = Args() # TODO: inject this?
def __getName(self, dreamId: str, postfix: str = '') -> str:
return f'{dreamId}{postfix}.png'
def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
name = self.__getName(dreamResult.id, postfix)
meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
return path
def path(self, dreamId: str, postfix: str = '') -> str:
name = self.__getName(dreamId, postfix)
path = os.path.join(self.__location, name)
return path
# Returns true if found, false if not found or error
def delete(self, dreamId: str, postfix: str = '') -> bool:
path = self.path(dreamId, postfix)
if (os.path.exists(path)):
os.remove(path)
return True
else:
return False
def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
path = self.path(dreamId, postfix)
image = Image.open(path)
text = image.text
if text.__contains__('Dream'):
dreamMeta = text.get('Dream')
try:
j = json.loads(dreamMeta)
return DreamResult.from_json(j)
except ValueError:
# Try to parse command-line format (legacy metadata format)
try:
opt = self.__parseLegacyMetadata(dreamMeta)
optd = opt.__dict__
if (not 'width' in optd) or (optd.get('width') is None):
optd['width'] = image.width
if (not 'height' in optd) or (optd.get('height') is None):
optd['height'] = image.height
if (not 'steps' in optd) or (optd.get('steps') is None):
optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously
optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)
return DreamResult.from_json(optd)
except:
return None
else:
return None
def __parseLegacyMetadata(self, command: str) -> DreamResult:
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
return None
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = self.__legacyParser.parse_cmd(switches)
return opt
except SystemExit:
return None
def list_files(self, page: int, perPage: int) -> PaginatedItems:
files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
count = len(files)
startId = page * perPage
pageCount = int(count / perPage) + 1
endId = min(startId + perPage, count)
items = [] if startId >= count else files[startId:endId]
items = list(map(lambda f: Path(f).stem, items))
return PaginatedItems(items, page, pageCount, perPage, count)
class GeneratorService:
__model: Generate
__queue: JobQueueService
__imageStorage: ImageStorageService
__intermediateStorage: ImageStorageService
__log: LogService
__thread: Thread
__cancellationRequested: bool = False
__signal_service: SignalService
def __init__(self, model: Generate, queue: JobQueueService, imageStorage: ImageStorageService, intermediateStorage: ImageStorageService, log: LogService, signal_service: SignalService):
self.__model = model
self.__queue = queue
self.__imageStorage = imageStorage
self.__intermediateStorage = intermediateStorage
self.__log = log
self.__signal_service = signal_service
# Create the background thread
self.__thread = Thread(target=self.__process, name = "GeneratorService")
self.__thread.daemon = True
self.__thread.start()
# Request cancellation of the current job
def cancel(self):
self.__cancellationRequested = True
# TODO: Consider moving this to its own service if there's benefit in separating the generator
def __process(self):
# preload the model
# TODO: support multiple models
print('Preloading model')
tic = time.time()
self.__model.load_model()
print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))
print('Started generation queue processor')
try:
while True:
dreamRequest = self.__queue.get()
self.__generate(dreamRequest)
except KeyboardInterrupt:
print('Generation queue processor stopped')
def __on_start(self, jobRequest: JobRequest):
self.__signal_service.emit(Signal.job_started(jobRequest.id))
def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
dreamResult = jobRequest.newDreamResult()
dreamResult.seed = seed
dreamResult.has_upscaled = upscaled
dreamResult.iterations = 1
jobRequest.results.append(dreamResult)
# TODO: Separate status of GFPGAN?
self.__imageStorage.save(image, dreamResult)
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
if not upscaled:
self.__log.log(dreamResult)
# Send result signal
self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))
upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
# Report upscaling status
# TODO: this is very coupled to logic inside the generator. Fix that.
if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))
def __on_progress(self, jobRequest: JobRequest, sample, step):
if self.__cancellationRequested:
self.__cancellationRequested = False
raise CanceledException
# TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
hasProgressImage = False
s = str(len(jobRequest.results))
if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
image = self.__model._sample_to_image(sample)
# TODO: clean this up, use a pre-defined dream result
result = DreamResult()
result.parse_json(jobRequest.__dict__, new_instance=False)
self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
hasProgressImage = True
self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))
def __generate(self, jobRequest: JobRequest):
try:
# TODO: handle this file a file service for init images
initimgfile = None # TODO: support this on the model directly?
if (jobRequest.enable_init_image):
if jobRequest.initimg is not None:
with open("./img2img-tmp.png", "wb") as f:
initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
initimgfile = "./img2img-tmp.png"
# Use previous seed if set to -1
initSeed = jobRequest.seed
if initSeed == -1:
initSeed = self.__model.seed
# Zero gfpgan strength if the model doesn't exist
# TODO: determine if this could be at the top now? Used to cause circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
if not gfpgan_model_exists:
jobRequest.enable_gfpgan = False
# Signal start
self.__on_start(jobRequest)
# Generate in model
# TODO: Split job generation requests instead of fitting all parameters here
# TODO: Support no generation (just upscaling/gfpgan)
upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
facetool_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.facetool_strength
if not jobRequest.enable_generate:
# If not generating, check if we're upscaling or running gfpgan
if not upscale and not facetool_strength:
# Invalid settings (TODO: Add message to help user)
raise CanceledException()
image = Image.open(initimgfile)
# TODO: support progress for upscale?
self.__model.upscale_and_reconstruct(
image_list = [[image,0]],
upscale = upscale,
strength = facetool_strength,
save_original = False,
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
else:
# Generating - run the generation
init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile
self.__model.prompt2image(
prompt = jobRequest.prompt,
init_img = init_img, # TODO: ensure this works
strength = None if init_img is None else jobRequest.strength,
fit = None if init_img is None else jobRequest.fit,
iterations = jobRequest.iterations,
cfg_scale = jobRequest.cfg_scale,
threshold = jobRequest.threshold,
perlin = jobRequest.perlin,
width = jobRequest.width,
height = jobRequest.height,
seed = jobRequest.seed,
steps = jobRequest.steps,
variation_amount = jobRequest.variation_amount,
with_variations = jobRequest.with_variations,
facetool_strength = facetool_strength,
upscale = upscale,
sampler_name = jobRequest.sampler_name,
seamless = jobRequest.seamless,
hires_fix = jobRequest.hires_fix,
embiggen = jobRequest.embiggen,
embiggen_tiles = jobRequest.embiggen_tiles,
step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step),
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
except CanceledException:
self.__signal_service.emit(Signal.job_canceled(jobRequest.id))
finally:
self.__signal_service.emit(Signal.job_done(jobRequest.id))
# Remove the temp file
if (initimgfile is not None):
os.remove("./img2img-tmp.png")

View File

@ -1,132 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Views module."""
import json
import os
from queue import Queue
from flask import current_app, jsonify, request, Response, send_from_directory, stream_with_context, url_for
from flask.views import MethodView
from dependency_injector.wiring import inject, Provide
from server.models import DreamResult, JobRequest
from server.services import GeneratorService, ImageStorageService, JobQueueService
from server.containers import Container
class ApiJobs(MethodView):
@inject
def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]):
jobRequest = JobRequest.from_json(request.json)
print(f">> Request to generate with prompt: {jobRequest.prompt}")
# Push the request
job_queue_service.push(jobRequest)
return { 'jobId': jobRequest.id }
class WebIndex(MethodView):
init_every_request = False
__file: str = None
def __init__(self, file):
self.__file = file
def get(self):
return current_app.send_static_file(self.__file)
class WebConfig(MethodView):
init_every_request = False
def get(self):
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
config = {
'gfpgan_model_exists': gfpgan_model_exists
}
js = f"let config = {json.dumps(config)};\n"
return Response(js, mimetype="application/javascript")
class ApiCancel(MethodView):
init_every_request = False
@inject
def get(self, generator_service: GeneratorService = Provide[Container.generator_service]):
generator_service.cancel()
return Response(status=204)
# TODO: Combine all image storage access
class ApiImages(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId):
name = self.__storage.path(dreamId)
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesMetadata(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId):
meta = self.__storage.getMetadata(dreamId)
j = {} if meta is None else meta.__dict__
return j
class ApiIntermediates(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId, step):
name = self.__storage.path(dreamId, postfix=f'.{step}')
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
def delete(self, dreamId):
result = self.__storage.delete(dreamId)
return Response(status=204) if result else Response(status=404)
class ApiImagesList(MethodView):
init_every_request = False
__storage: ImageStorageService
@inject
def __init__(self, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__storage = storage
def get(self):
page = request.args.get("page", default=0, type=int)
perPage = request.args.get("per_page", default=10, type=int)
result = self.__storage.list_files(page, perPage)
return result.__dict__