mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
delete old 'server' package and the dependency_injector requirement (#2032)
fixes #1944
This commit is contained in:
parent
ffa54f4a35
commit
7d09d9da49
@ -30,7 +30,6 @@ dependencies:
|
|||||||
- torchvision
|
- torchvision
|
||||||
- transformers=4.21.3
|
- transformers=4.21.3
|
||||||
- pip:
|
- pip:
|
||||||
- dependency_injector==4.40.0
|
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
- omegaconf==2.1.1
|
- omegaconf==2.1.1
|
||||||
- picklescan
|
- picklescan
|
||||||
|
@ -10,7 +10,6 @@ dependencies:
|
|||||||
- pip:
|
- pip:
|
||||||
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/
|
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- dependency_injector==4.40.0
|
|
||||||
- diffusers==0.6.0
|
- diffusers==0.6.0
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
|
@ -13,7 +13,6 @@ dependencies:
|
|||||||
- cudatoolkit=11.6
|
- cudatoolkit=11.6
|
||||||
- pip:
|
- pip:
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- dependency_injector==4.40.0
|
|
||||||
- diffusers==0.6.0
|
- diffusers==0.6.0
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
|
@ -13,7 +13,6 @@ dependencies:
|
|||||||
- cudatoolkit=11.6
|
- cudatoolkit=11.6
|
||||||
- pip:
|
- pip:
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- dependency_injector==4.40.0
|
|
||||||
- diffusers==0.6.0
|
- diffusers==0.6.0
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# pip will resolve the version which matches torch
|
# pip will resolve the version which matches torch
|
||||||
albumentations
|
albumentations
|
||||||
dependency_injector==4.40.0
|
|
||||||
diffusers==0.10.*
|
diffusers==0.10.*
|
||||||
einops
|
einops
|
||||||
eventlet
|
eventlet
|
||||||
|
@ -1,152 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
"""Application module."""
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from flask import Flask
|
|
||||||
from flask_cors import CORS
|
|
||||||
from flask_socketio import SocketIO
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from dependency_injector.wiring import inject, Provide
|
|
||||||
from ldm.invoke.args import Args
|
|
||||||
from server import views
|
|
||||||
from server.containers import Container
|
|
||||||
from server.services import GeneratorService, SignalService
|
|
||||||
|
|
||||||
# The socketio_service is injected here (rather than created in run_app) to initialize it
|
|
||||||
@inject
|
|
||||||
def initialize_app(
|
|
||||||
app: Flask,
|
|
||||||
socketio: SocketIO = Provide[Container.socketio]
|
|
||||||
) -> SocketIO:
|
|
||||||
socketio.init_app(app)
|
|
||||||
|
|
||||||
return socketio
|
|
||||||
|
|
||||||
# The signal and generator services are injected to warm up the processing queues
|
|
||||||
# TODO: Initialize these a better way?
|
|
||||||
@inject
|
|
||||||
def initialize_generator(
|
|
||||||
signal_service: SignalService = Provide[Container.signal_service],
|
|
||||||
generator_service: GeneratorService = Provide[Container.generator_service]
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def run_app(config, host, port) -> Flask:
|
|
||||||
app = Flask(__name__, static_url_path='')
|
|
||||||
|
|
||||||
# Set up dependency injection container
|
|
||||||
container = Container()
|
|
||||||
container.config.from_dict(config)
|
|
||||||
container.wire(modules=[__name__])
|
|
||||||
app.container = container
|
|
||||||
|
|
||||||
# Set up CORS
|
|
||||||
CORS(app, resources={r'/api/*': {'origins': '*'}})
|
|
||||||
|
|
||||||
# Web Routes
|
|
||||||
app.add_url_rule('/', view_func=views.WebIndex.as_view('web_index', 'index.html'))
|
|
||||||
app.add_url_rule('/index.css', view_func=views.WebIndex.as_view('web_index_css', 'index.css'))
|
|
||||||
app.add_url_rule('/index.js', view_func=views.WebIndex.as_view('web_index_js', 'index.js'))
|
|
||||||
app.add_url_rule('/config.js', view_func=views.WebConfig.as_view('web_config'))
|
|
||||||
|
|
||||||
# API Routes
|
|
||||||
app.add_url_rule('/api/jobs', view_func=views.ApiJobs.as_view('api_jobs'))
|
|
||||||
app.add_url_rule('/api/cancel', view_func=views.ApiCancel.as_view('api_cancel'))
|
|
||||||
|
|
||||||
# TODO: Get storage root from config
|
|
||||||
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
|
|
||||||
app.add_url_rule('/api/images/<string:dreamId>/metadata', view_func=views.ApiImagesMetadata.as_view('api_images_metadata', '../'))
|
|
||||||
app.add_url_rule('/api/images', view_func=views.ApiImagesList.as_view('api_images_list'))
|
|
||||||
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
|
|
||||||
|
|
||||||
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
|
|
||||||
|
|
||||||
# Initialize
|
|
||||||
socketio = initialize_app(app)
|
|
||||||
initialize_generator()
|
|
||||||
|
|
||||||
print(">> Started Stable Diffusion api server!")
|
|
||||||
if host == '0.0.0.0':
|
|
||||||
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
|
|
||||||
else:
|
|
||||||
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
|
|
||||||
print(f">> Point your browser at http://{host}:{port}.")
|
|
||||||
|
|
||||||
# Run the app
|
|
||||||
socketio.run(app, host, port)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Initialize command-line parsers and the diffusion model"""
|
|
||||||
arg_parser = Args()
|
|
||||||
opt = arg_parser.parse_args()
|
|
||||||
|
|
||||||
if opt.laion400m:
|
|
||||||
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
|
|
||||||
sys.exit(-1)
|
|
||||||
if opt.weights:
|
|
||||||
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# models = OmegaConf.load(opt.config)
|
|
||||||
# width = models[opt.model].width
|
|
||||||
# height = models[opt.model].height
|
|
||||||
# config = models[opt.model].config
|
|
||||||
# weights = models[opt.model].weights
|
|
||||||
# except (FileNotFoundError, IOError, KeyError) as e:
|
|
||||||
# print(f'{e}. Aborting.')
|
|
||||||
# sys.exit(-1)
|
|
||||||
|
|
||||||
#print('* Initializing, be patient...\n')
|
|
||||||
sys.path.append('.')
|
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
|
||||||
# when the frozen CLIP tokenizer is imported
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
appConfig = opt.__dict__
|
|
||||||
|
|
||||||
# appConfig = {
|
|
||||||
# "model": {
|
|
||||||
# "width": width,
|
|
||||||
# "height": height,
|
|
||||||
# "sampler_name": opt.sampler_name,
|
|
||||||
# "weights": weights,
|
|
||||||
# "precision": opt.precision,
|
|
||||||
# "config": config,
|
|
||||||
# "grid": opt.grid,
|
|
||||||
# "latent_diffusion_weights": opt.laion400m,
|
|
||||||
# "embedding_path": opt.embedding_path
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
|
|
||||||
# make sure the output directory exists
|
|
||||||
if not os.path.exists(opt.outdir):
|
|
||||||
os.makedirs(opt.outdir)
|
|
||||||
|
|
||||||
# gets rid of annoying messages about random seed
|
|
||||||
from pytorch_lightning import logging
|
|
||||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
print('\n* starting api server...')
|
|
||||||
# Change working directory to the stable-diffusion directory
|
|
||||||
os.chdir(
|
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start server
|
|
||||||
try:
|
|
||||||
run_app(appConfig, opt.host, opt.port)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,81 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
"""Containers module."""
|
|
||||||
|
|
||||||
from dependency_injector import containers, providers
|
|
||||||
from flask_socketio import SocketIO
|
|
||||||
from ldm.generate import Generate
|
|
||||||
from server import services
|
|
||||||
|
|
||||||
class Container(containers.DeclarativeContainer):
|
|
||||||
wiring_config = containers.WiringConfiguration(packages=['server'])
|
|
||||||
|
|
||||||
config = providers.Configuration()
|
|
||||||
|
|
||||||
socketio = providers.ThreadSafeSingleton(
|
|
||||||
SocketIO,
|
|
||||||
app = None
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Add a model provider service that provides model(s) dynamically
|
|
||||||
model_singleton = providers.ThreadSafeSingleton(
|
|
||||||
Generate,
|
|
||||||
model = config.model,
|
|
||||||
sampler_name = config.sampler_name,
|
|
||||||
embedding_path = config.embedding_path,
|
|
||||||
precision = config.precision
|
|
||||||
# config = config.model.config,
|
|
||||||
|
|
||||||
# width = config.model.width,
|
|
||||||
# height = config.model.height,
|
|
||||||
# sampler_name = config.model.sampler_name,
|
|
||||||
# weights = config.model.weights,
|
|
||||||
# precision = config.model.precision,
|
|
||||||
# grid = config.model.grid,
|
|
||||||
# seamless = config.model.seamless,
|
|
||||||
# embedding_path = config.model.embedding_path,
|
|
||||||
# device_type = config.model.device_type
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: get location from config
|
|
||||||
image_storage_service = providers.ThreadSafeSingleton(
|
|
||||||
services.ImageStorageService,
|
|
||||||
'./outputs/img-samples/'
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: get location from config
|
|
||||||
image_intermediates_storage_service = providers.ThreadSafeSingleton(
|
|
||||||
services.ImageStorageService,
|
|
||||||
'./outputs/intermediates/'
|
|
||||||
)
|
|
||||||
|
|
||||||
signal_queue_service = providers.ThreadSafeSingleton(
|
|
||||||
services.SignalQueueService
|
|
||||||
)
|
|
||||||
|
|
||||||
signal_service = providers.ThreadSafeSingleton(
|
|
||||||
services.SignalService,
|
|
||||||
socketio = socketio,
|
|
||||||
queue = signal_queue_service
|
|
||||||
)
|
|
||||||
|
|
||||||
generation_queue_service = providers.ThreadSafeSingleton(
|
|
||||||
services.JobQueueService
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: get locations from config
|
|
||||||
log_service = providers.ThreadSafeSingleton(
|
|
||||||
services.LogService,
|
|
||||||
'./outputs/img-samples/',
|
|
||||||
'dream_web_log.txt'
|
|
||||||
)
|
|
||||||
|
|
||||||
generator_service = providers.ThreadSafeSingleton(
|
|
||||||
services.GeneratorService,
|
|
||||||
model = model_singleton,
|
|
||||||
queue = generation_queue_service,
|
|
||||||
imageStorage = image_storage_service,
|
|
||||||
intermediateStorage = image_intermediates_storage_service,
|
|
||||||
log = log_service,
|
|
||||||
signal_service = signal_service
|
|
||||||
)
|
|
259
server/models.py
259
server/models.py
@ -1,259 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from base64 import urlsafe_b64encode
|
|
||||||
import json
|
|
||||||
import string
|
|
||||||
from copy import deepcopy
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, List, Union
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
|
|
||||||
class DreamBase():
|
|
||||||
# Id
|
|
||||||
id: str
|
|
||||||
|
|
||||||
# Initial Image
|
|
||||||
enable_init_image: bool
|
|
||||||
initimg: string = None
|
|
||||||
|
|
||||||
# Img2Img
|
|
||||||
enable_img2img: bool # TODO: support this better
|
|
||||||
strength: float = 0 # TODO: name this something related to img2img to make it clearer?
|
|
||||||
fit = None # Fit initial image dimensions
|
|
||||||
|
|
||||||
# Generation
|
|
||||||
enable_generate: bool
|
|
||||||
prompt: string = ""
|
|
||||||
seed: int = 0 # 0 is random
|
|
||||||
steps: int = 10
|
|
||||||
width: int = 512
|
|
||||||
height: int = 512
|
|
||||||
cfg_scale: float = 7.5
|
|
||||||
threshold: float = 0.0
|
|
||||||
perlin: float = 0.0
|
|
||||||
sampler_name: string = 'klms'
|
|
||||||
seamless: bool = False
|
|
||||||
hires_fix: bool = False
|
|
||||||
model: str = None # The model to use (currently unused)
|
|
||||||
embeddings = None # The embeddings to use (currently unused)
|
|
||||||
progress_images: bool = False
|
|
||||||
progress_latents: bool = False
|
|
||||||
|
|
||||||
# GFPGAN
|
|
||||||
enable_gfpgan: bool
|
|
||||||
facetool_strength: float = 0
|
|
||||||
|
|
||||||
# Upscale
|
|
||||||
enable_upscale: bool
|
|
||||||
upscale: None
|
|
||||||
upscale_level: int = None
|
|
||||||
upscale_strength: float = 0.75
|
|
||||||
|
|
||||||
# Embiggen
|
|
||||||
enable_embiggen: bool
|
|
||||||
embiggen: Union[None, List[float]] = None
|
|
||||||
embiggen_tiles: Union[None, List[int]] = None
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
time: int
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.id = urlsafe_b64encode(uuid4().bytes).decode('ascii')
|
|
||||||
|
|
||||||
def parse_json(self, j, new_instance=False):
|
|
||||||
# Id
|
|
||||||
if 'id' in j and not new_instance:
|
|
||||||
self.id = j.get('id')
|
|
||||||
|
|
||||||
# Initial Image
|
|
||||||
self.enable_init_image = 'enable_init_image' in j and bool(j.get('enable_init_image'))
|
|
||||||
if self.enable_init_image:
|
|
||||||
self.initimg = j.get('initimg')
|
|
||||||
|
|
||||||
# Img2Img
|
|
||||||
self.enable_img2img = 'enable_img2img' in j and bool(j.get('enable_img2img'))
|
|
||||||
if self.enable_img2img:
|
|
||||||
self.strength = float(j.get('strength'))
|
|
||||||
self.fit = 'fit' in j
|
|
||||||
|
|
||||||
# Generation
|
|
||||||
self.enable_generate = 'enable_generate' in j and bool(j.get('enable_generate'))
|
|
||||||
if self.enable_generate:
|
|
||||||
self.prompt = j.get('prompt')
|
|
||||||
self.seed = int(j.get('seed'))
|
|
||||||
self.steps = int(j.get('steps'))
|
|
||||||
self.width = int(j.get('width'))
|
|
||||||
self.height = int(j.get('height'))
|
|
||||||
self.cfg_scale = float(j.get('cfgscale') or j.get('cfg_scale'))
|
|
||||||
self.threshold = float(j.get('threshold'))
|
|
||||||
self.perlin = float(j.get('perlin'))
|
|
||||||
self.sampler_name = j.get('sampler') or j.get('sampler_name')
|
|
||||||
# model: str = None # The model to use (currently unused)
|
|
||||||
# embeddings = None # The embeddings to use (currently unused)
|
|
||||||
self.seamless = 'seamless' in j
|
|
||||||
self.hires_fix = 'hires_fix' in j
|
|
||||||
self.progress_images = 'progress_images' in j
|
|
||||||
self.progress_latents = 'progress_latents' in j
|
|
||||||
|
|
||||||
# GFPGAN
|
|
||||||
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))
|
|
||||||
if self.enable_gfpgan:
|
|
||||||
self.facetool_strength = float(j.get('facetool_strength'))
|
|
||||||
|
|
||||||
# Upscale
|
|
||||||
self.enable_upscale = 'enable_upscale' in j and bool(j.get('enable_upscale'))
|
|
||||||
if self.enable_upscale:
|
|
||||||
self.upscale_level = j.get('upscale_level')
|
|
||||||
self.upscale_strength = j.get('upscale_strength')
|
|
||||||
self.upscale = None if self.upscale_level in {None,''} else [int(self.upscale_level),float(self.upscale_strength)]
|
|
||||||
|
|
||||||
# Embiggen
|
|
||||||
self.enable_embiggen = 'enable_embiggen' in j and bool(j.get('enable_embiggen'))
|
|
||||||
if self.enable_embiggen:
|
|
||||||
self.embiggen = j.get('embiggen')
|
|
||||||
self.embiggen_tiles = j.get('embiggen_tiles')
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
self.time = int(j.get('time')) if ('time' in j and not new_instance) else int(datetime.now(timezone.utc).timestamp())
|
|
||||||
|
|
||||||
|
|
||||||
class DreamResult(DreamBase):
|
|
||||||
# Result
|
|
||||||
has_upscaled: False
|
|
||||||
has_gfpgan: False
|
|
||||||
|
|
||||||
# TODO: use something else for state tracking
|
|
||||||
images_generated: int = 0
|
|
||||||
images_upscaled: int = 0
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def clone_without_img(self):
|
|
||||||
copy = deepcopy(self)
|
|
||||||
copy.initimg = None
|
|
||||||
return copy
|
|
||||||
|
|
||||||
def to_json(self):
|
|
||||||
copy = deepcopy(self)
|
|
||||||
copy.initimg = None
|
|
||||||
j = json.dumps(copy.__dict__)
|
|
||||||
return j
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_json(j, newTime: bool = False):
|
|
||||||
d = DreamResult()
|
|
||||||
d.parse_json(j)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: switch this to a pipelined request, with pluggable steps
|
|
||||||
# Will likely require generator code changes to accomplish
|
|
||||||
class JobRequest(DreamBase):
|
|
||||||
# Iteration
|
|
||||||
iterations: int = 1
|
|
||||||
variation_amount = None
|
|
||||||
with_variations = None
|
|
||||||
|
|
||||||
# Results
|
|
||||||
results: List[DreamResult] = []
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def newDreamResult(self) -> DreamResult:
|
|
||||||
result = DreamResult()
|
|
||||||
result.parse_json(self.__dict__, new_instance=True)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_json(j):
|
|
||||||
job = JobRequest()
|
|
||||||
job.parse_json(j)
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
job.time = int(j.get('time')) if ('time' in j) else int(datetime.now(timezone.utc).timestamp())
|
|
||||||
|
|
||||||
# Iteration
|
|
||||||
if job.enable_generate:
|
|
||||||
job.iterations = int(j.get('iterations'))
|
|
||||||
job.variation_amount = float(j.get('variation_amount'))
|
|
||||||
job.with_variations = j.get('with_variations')
|
|
||||||
|
|
||||||
return job
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressType(Enum):
|
|
||||||
GENERATION = 1
|
|
||||||
UPSCALING_STARTED = 2
|
|
||||||
UPSCALING_DONE = 3
|
|
||||||
|
|
||||||
class Signal():
|
|
||||||
event: str
|
|
||||||
data = None
|
|
||||||
room: str = None
|
|
||||||
broadcast: bool = False
|
|
||||||
|
|
||||||
def __init__(self, event: str, data, room: str = None, broadcast: bool = False):
|
|
||||||
self.event = event
|
|
||||||
self.data = data
|
|
||||||
self.room = room
|
|
||||||
self.broadcast = broadcast
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def image_progress(jobId: str, dreamId: str, step: int, totalSteps: int, progressType: ProgressType = ProgressType.GENERATION, hasProgressImage: bool = False):
|
|
||||||
return Signal('dream_progress', {
|
|
||||||
'jobId': jobId,
|
|
||||||
'dreamId': dreamId,
|
|
||||||
'step': step,
|
|
||||||
'totalSteps': totalSteps,
|
|
||||||
'hasProgressImage': hasProgressImage,
|
|
||||||
'progressType': progressType.name
|
|
||||||
}, room=jobId, broadcast=True)
|
|
||||||
|
|
||||||
# TODO: use a result id or something? Like a sub-job
|
|
||||||
@staticmethod
|
|
||||||
def image_result(jobId: str, dreamId: str, dreamResult: DreamResult):
|
|
||||||
return Signal('dream_result', {
|
|
||||||
'jobId': jobId,
|
|
||||||
'dreamId': dreamId,
|
|
||||||
'dreamRequest': dreamResult.clone_without_img().__dict__
|
|
||||||
}, room=jobId, broadcast=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def job_started(jobId: str):
|
|
||||||
return Signal('job_started', {
|
|
||||||
'jobId': jobId
|
|
||||||
}, room=jobId, broadcast=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def job_done(jobId: str):
|
|
||||||
return Signal('job_done', {
|
|
||||||
'jobId': jobId
|
|
||||||
}, room=jobId, broadcast=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def job_canceled(jobId: str):
|
|
||||||
return Signal('job_canceled', {
|
|
||||||
'jobId': jobId
|
|
||||||
}, room=jobId, broadcast=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PaginatedItems():
|
|
||||||
items: List[Any]
|
|
||||||
page: int # Current Page
|
|
||||||
pages: int # Total number of pages
|
|
||||||
per_page: int # Number of items per page
|
|
||||||
total: int # Total number of items in result
|
|
||||||
|
|
||||||
def __init__(self, items: List[Any], page: int, pages: int, per_page: int, total: int):
|
|
||||||
self.items = items
|
|
||||||
self.page = page
|
|
||||||
self.pages = pages
|
|
||||||
self.per_page = per_page
|
|
||||||
self.total = total
|
|
||||||
|
|
||||||
def to_json(self):
|
|
||||||
return json.dumps(self.__dict__)
|
|
@ -1,392 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
import base64
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import glob
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Empty, Queue
|
|
||||||
import shlex
|
|
||||||
from threading import Thread
|
|
||||||
import time
|
|
||||||
from flask_socketio import SocketIO, join_room, leave_room
|
|
||||||
from ldm.invoke.args import Args
|
|
||||||
from ldm.invoke.generator import embiggen
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ldm.invoke.pngwriter import PngWriter
|
|
||||||
from ldm.invoke.server import CanceledException
|
|
||||||
from ldm.generate import Generate
|
|
||||||
from server.models import DreamResult, JobRequest, PaginatedItems, ProgressType, Signal
|
|
||||||
|
|
||||||
class JobQueueService:
|
|
||||||
__queue: Queue = Queue()
|
|
||||||
|
|
||||||
def push(self, dreamRequest: DreamResult):
|
|
||||||
self.__queue.put(dreamRequest)
|
|
||||||
|
|
||||||
def get(self, timeout: float = None) -> DreamResult:
|
|
||||||
return self.__queue.get(timeout= timeout)
|
|
||||||
|
|
||||||
class SignalQueueService:
|
|
||||||
__queue: Queue = Queue()
|
|
||||||
|
|
||||||
def push(self, signal: Signal):
|
|
||||||
self.__queue.put(signal)
|
|
||||||
|
|
||||||
def get(self) -> Signal:
|
|
||||||
return self.__queue.get(block=False)
|
|
||||||
|
|
||||||
|
|
||||||
class SignalService:
|
|
||||||
__socketio: SocketIO
|
|
||||||
__queue: SignalQueueService
|
|
||||||
|
|
||||||
def __init__(self, socketio: SocketIO, queue: SignalQueueService):
|
|
||||||
self.__socketio = socketio
|
|
||||||
self.__queue = queue
|
|
||||||
|
|
||||||
def on_join(data):
|
|
||||||
room = data['room']
|
|
||||||
join_room(room)
|
|
||||||
self.__socketio.emit("test", "something", room=room)
|
|
||||||
|
|
||||||
def on_leave(data):
|
|
||||||
room = data['room']
|
|
||||||
leave_room(room)
|
|
||||||
|
|
||||||
self.__socketio.on_event('join_room', on_join)
|
|
||||||
self.__socketio.on_event('leave_room', on_leave)
|
|
||||||
|
|
||||||
self.__socketio.start_background_task(self.__process)
|
|
||||||
|
|
||||||
def __process(self):
|
|
||||||
# preload the model
|
|
||||||
print('Started signal queue processor')
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
signal = self.__queue.get()
|
|
||||||
self.__socketio.emit(signal.event, signal.data, room=signal.room, broadcast=signal.broadcast)
|
|
||||||
except Empty:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
self.__socketio.sleep(0.001)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print('Signal queue processor stopped')
|
|
||||||
|
|
||||||
|
|
||||||
def emit(self, signal: Signal):
|
|
||||||
self.__queue.push(signal)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Name this better?
|
|
||||||
# TODO: Logging and signals should probably be event based (multiple listeners for an event)
|
|
||||||
class LogService:
|
|
||||||
__location: str
|
|
||||||
__logFile: str
|
|
||||||
|
|
||||||
def __init__(self, location:str, file:str):
|
|
||||||
self.__location = location
|
|
||||||
self.__logFile = file
|
|
||||||
|
|
||||||
def log(self, dreamResult: DreamResult, seed = None, upscaled = False):
|
|
||||||
with open(os.path.join(self.__location, self.__logFile), "a") as log:
|
|
||||||
log.write(f"{dreamResult.id}: {dreamResult.to_json()}\n")
|
|
||||||
|
|
||||||
|
|
||||||
class ImageStorageService:
|
|
||||||
__location: str
|
|
||||||
__pngWriter: PngWriter
|
|
||||||
__legacyParser: ArgumentParser
|
|
||||||
|
|
||||||
def __init__(self, location):
|
|
||||||
self.__location = location
|
|
||||||
self.__pngWriter = PngWriter(self.__location)
|
|
||||||
self.__legacyParser = Args() # TODO: inject this?
|
|
||||||
|
|
||||||
def __getName(self, dreamId: str, postfix: str = '') -> str:
|
|
||||||
return f'{dreamId}{postfix}.png'
|
|
||||||
|
|
||||||
def save(self, image, dreamResult: DreamResult, postfix: str = '') -> str:
|
|
||||||
name = self.__getName(dreamResult.id, postfix)
|
|
||||||
meta = dreamResult.to_json() # TODO: make all methods consistent with writing metadata. Standardize metadata.
|
|
||||||
path = self.__pngWriter.save_image_and_prompt_to_png(image, dream_prompt=meta, metadata=None, name=name)
|
|
||||||
return path
|
|
||||||
|
|
||||||
def path(self, dreamId: str, postfix: str = '') -> str:
|
|
||||||
name = self.__getName(dreamId, postfix)
|
|
||||||
path = os.path.join(self.__location, name)
|
|
||||||
return path
|
|
||||||
|
|
||||||
# Returns true if found, false if not found or error
|
|
||||||
def delete(self, dreamId: str, postfix: str = '') -> bool:
|
|
||||||
path = self.path(dreamId, postfix)
|
|
||||||
if (os.path.exists(path)):
|
|
||||||
os.remove(path)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def getMetadata(self, dreamId: str, postfix: str = '') -> DreamResult:
|
|
||||||
path = self.path(dreamId, postfix)
|
|
||||||
image = Image.open(path)
|
|
||||||
text = image.text
|
|
||||||
if text.__contains__('Dream'):
|
|
||||||
dreamMeta = text.get('Dream')
|
|
||||||
try:
|
|
||||||
j = json.loads(dreamMeta)
|
|
||||||
return DreamResult.from_json(j)
|
|
||||||
except ValueError:
|
|
||||||
# Try to parse command-line format (legacy metadata format)
|
|
||||||
try:
|
|
||||||
opt = self.__parseLegacyMetadata(dreamMeta)
|
|
||||||
optd = opt.__dict__
|
|
||||||
if (not 'width' in optd) or (optd.get('width') is None):
|
|
||||||
optd['width'] = image.width
|
|
||||||
if (not 'height' in optd) or (optd.get('height') is None):
|
|
||||||
optd['height'] = image.height
|
|
||||||
if (not 'steps' in optd) or (optd.get('steps') is None):
|
|
||||||
optd['steps'] = 10 # No way around this unfortunately - seems like it wasn't storing this previously
|
|
||||||
|
|
||||||
optd['time'] = os.path.getmtime(path) # Set timestamp manually (won't be exactly correct though)
|
|
||||||
|
|
||||||
return DreamResult.from_json(optd)
|
|
||||||
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __parseLegacyMetadata(self, command: str) -> DreamResult:
|
|
||||||
# before splitting, escape single quotes so as not to mess
|
|
||||||
# up the parser
|
|
||||||
command = command.replace("'", "\\'")
|
|
||||||
|
|
||||||
try:
|
|
||||||
elements = shlex.split(command)
|
|
||||||
except ValueError as e:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# rearrange the arguments to mimic how it works in the Dream bot.
|
|
||||||
switches = ['']
|
|
||||||
switches_started = False
|
|
||||||
|
|
||||||
for el in elements:
|
|
||||||
if el[0] == '-' and not switches_started:
|
|
||||||
switches_started = True
|
|
||||||
if switches_started:
|
|
||||||
switches.append(el)
|
|
||||||
else:
|
|
||||||
switches[0] += el
|
|
||||||
switches[0] += ' '
|
|
||||||
switches[0] = switches[0][: len(switches[0]) - 1]
|
|
||||||
|
|
||||||
try:
|
|
||||||
opt = self.__legacyParser.parse_cmd(switches)
|
|
||||||
return opt
|
|
||||||
except SystemExit:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def list_files(self, page: int, perPage: int) -> PaginatedItems:
|
|
||||||
files = sorted(glob.glob(os.path.join(self.__location,'*.png')), key=os.path.getmtime, reverse=True)
|
|
||||||
count = len(files)
|
|
||||||
|
|
||||||
startId = page * perPage
|
|
||||||
pageCount = int(count / perPage) + 1
|
|
||||||
endId = min(startId + perPage, count)
|
|
||||||
items = [] if startId >= count else files[startId:endId]
|
|
||||||
|
|
||||||
items = list(map(lambda f: Path(f).stem, items))
|
|
||||||
|
|
||||||
return PaginatedItems(items, page, pageCount, perPage, count)
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorService:
|
|
||||||
__model: Generate
|
|
||||||
__queue: JobQueueService
|
|
||||||
__imageStorage: ImageStorageService
|
|
||||||
__intermediateStorage: ImageStorageService
|
|
||||||
__log: LogService
|
|
||||||
__thread: Thread
|
|
||||||
__cancellationRequested: bool = False
|
|
||||||
__signal_service: SignalService
|
|
||||||
|
|
||||||
def __init__(self, model: Generate, queue: JobQueueService, imageStorage: ImageStorageService, intermediateStorage: ImageStorageService, log: LogService, signal_service: SignalService):
|
|
||||||
self.__model = model
|
|
||||||
self.__queue = queue
|
|
||||||
self.__imageStorage = imageStorage
|
|
||||||
self.__intermediateStorage = intermediateStorage
|
|
||||||
self.__log = log
|
|
||||||
self.__signal_service = signal_service
|
|
||||||
|
|
||||||
# Create the background thread
|
|
||||||
self.__thread = Thread(target=self.__process, name = "GeneratorService")
|
|
||||||
self.__thread.daemon = True
|
|
||||||
self.__thread.start()
|
|
||||||
|
|
||||||
|
|
||||||
# Request cancellation of the current job
|
|
||||||
def cancel(self):
|
|
||||||
self.__cancellationRequested = True
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Consider moving this to its own service if there's benefit in separating the generator
|
|
||||||
def __process(self):
|
|
||||||
# preload the model
|
|
||||||
# TODO: support multiple models
|
|
||||||
print('Preloading model')
|
|
||||||
tic = time.time()
|
|
||||||
self.__model.load_model()
|
|
||||||
print(f'>> model loaded in', '%4.2fs' % (time.time() - tic))
|
|
||||||
|
|
||||||
print('Started generation queue processor')
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
dreamRequest = self.__queue.get()
|
|
||||||
self.__generate(dreamRequest)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print('Generation queue processor stopped')
|
|
||||||
|
|
||||||
|
|
||||||
def __on_start(self, jobRequest: JobRequest):
|
|
||||||
self.__signal_service.emit(Signal.job_started(jobRequest.id))
|
|
||||||
|
|
||||||
|
|
||||||
def __on_image_result(self, jobRequest: JobRequest, image, seed, upscaled=False):
|
|
||||||
dreamResult = jobRequest.newDreamResult()
|
|
||||||
dreamResult.seed = seed
|
|
||||||
dreamResult.has_upscaled = upscaled
|
|
||||||
dreamResult.iterations = 1
|
|
||||||
jobRequest.results.append(dreamResult)
|
|
||||||
# TODO: Separate status of GFPGAN?
|
|
||||||
|
|
||||||
self.__imageStorage.save(image, dreamResult)
|
|
||||||
|
|
||||||
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
|
|
||||||
if not upscaled:
|
|
||||||
self.__log.log(dreamResult)
|
|
||||||
|
|
||||||
# Send result signal
|
|
||||||
self.__signal_service.emit(Signal.image_result(jobRequest.id, dreamResult.id, dreamResult))
|
|
||||||
|
|
||||||
upscaling_requested = dreamResult.enable_upscale or dreamResult.enable_gfpgan
|
|
||||||
|
|
||||||
# Report upscaling status
|
|
||||||
# TODO: this is very coupled to logic inside the generator. Fix that.
|
|
||||||
if upscaling_requested and any(result.has_upscaled for result in jobRequest.results):
|
|
||||||
progressType = ProgressType.UPSCALING_STARTED if len(jobRequest.results) < 2 * jobRequest.iterations else ProgressType.UPSCALING_DONE
|
|
||||||
upscale_count = sum(1 for i in jobRequest.results if i.has_upscaled)
|
|
||||||
self.__signal_service.emit(Signal.image_progress(jobRequest.id, dreamResult.id, upscale_count, jobRequest.iterations, progressType))
|
|
||||||
|
|
||||||
|
|
||||||
def __on_progress(self, jobRequest: JobRequest, sample, step):
|
|
||||||
if self.__cancellationRequested:
|
|
||||||
self.__cancellationRequested = False
|
|
||||||
raise CanceledException
|
|
||||||
|
|
||||||
# TODO: Progress per request will be easier once the seeds (and ids) can all be pre-generated
|
|
||||||
hasProgressImage = False
|
|
||||||
s = str(len(jobRequest.results))
|
|
||||||
if jobRequest.progress_images and step % 5 == 0 and step < jobRequest.steps - 1:
|
|
||||||
image = self.__model._sample_to_image(sample)
|
|
||||||
|
|
||||||
# TODO: clean this up, use a pre-defined dream result
|
|
||||||
result = DreamResult()
|
|
||||||
result.parse_json(jobRequest.__dict__, new_instance=False)
|
|
||||||
self.__intermediateStorage.save(image, result, postfix=f'.{s}.{step}')
|
|
||||||
hasProgressImage = True
|
|
||||||
|
|
||||||
self.__signal_service.emit(Signal.image_progress(jobRequest.id, f'{jobRequest.id}.{s}', step, jobRequest.steps, ProgressType.GENERATION, hasProgressImage))
|
|
||||||
|
|
||||||
|
|
||||||
def __generate(self, jobRequest: JobRequest):
|
|
||||||
try:
|
|
||||||
# TODO: handle this file a file service for init images
|
|
||||||
initimgfile = None # TODO: support this on the model directly?
|
|
||||||
if (jobRequest.enable_init_image):
|
|
||||||
if jobRequest.initimg is not None:
|
|
||||||
with open("./img2img-tmp.png", "wb") as f:
|
|
||||||
initimg = jobRequest.initimg.split(",")[1] # Ignore mime type
|
|
||||||
f.write(base64.b64decode(initimg))
|
|
||||||
initimgfile = "./img2img-tmp.png"
|
|
||||||
|
|
||||||
# Use previous seed if set to -1
|
|
||||||
initSeed = jobRequest.seed
|
|
||||||
if initSeed == -1:
|
|
||||||
initSeed = self.__model.seed
|
|
||||||
|
|
||||||
# Zero gfpgan strength if the model doesn't exist
|
|
||||||
# TODO: determine if this could be at the top now? Used to cause circular import
|
|
||||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
|
||||||
if not gfpgan_model_exists:
|
|
||||||
jobRequest.enable_gfpgan = False
|
|
||||||
|
|
||||||
# Signal start
|
|
||||||
self.__on_start(jobRequest)
|
|
||||||
|
|
||||||
# Generate in model
|
|
||||||
# TODO: Split job generation requests instead of fitting all parameters here
|
|
||||||
# TODO: Support no generation (just upscaling/gfpgan)
|
|
||||||
|
|
||||||
upscale = None if not jobRequest.enable_upscale else jobRequest.upscale
|
|
||||||
facetool_strength = 0 if not jobRequest.enable_gfpgan else jobRequest.facetool_strength
|
|
||||||
|
|
||||||
if not jobRequest.enable_generate:
|
|
||||||
# If not generating, check if we're upscaling or running gfpgan
|
|
||||||
if not upscale and not facetool_strength:
|
|
||||||
# Invalid settings (TODO: Add message to help user)
|
|
||||||
raise CanceledException()
|
|
||||||
|
|
||||||
image = Image.open(initimgfile)
|
|
||||||
# TODO: support progress for upscale?
|
|
||||||
self.__model.upscale_and_reconstruct(
|
|
||||||
image_list = [[image,0]],
|
|
||||||
upscale = upscale,
|
|
||||||
strength = facetool_strength,
|
|
||||||
save_original = False,
|
|
||||||
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Generating - run the generation
|
|
||||||
init_img = None if (not jobRequest.enable_img2img or jobRequest.strength == 0) else initimgfile
|
|
||||||
|
|
||||||
|
|
||||||
self.__model.prompt2image(
|
|
||||||
prompt = jobRequest.prompt,
|
|
||||||
init_img = init_img, # TODO: ensure this works
|
|
||||||
strength = None if init_img is None else jobRequest.strength,
|
|
||||||
fit = None if init_img is None else jobRequest.fit,
|
|
||||||
iterations = jobRequest.iterations,
|
|
||||||
cfg_scale = jobRequest.cfg_scale,
|
|
||||||
threshold = jobRequest.threshold,
|
|
||||||
perlin = jobRequest.perlin,
|
|
||||||
width = jobRequest.width,
|
|
||||||
height = jobRequest.height,
|
|
||||||
seed = jobRequest.seed,
|
|
||||||
steps = jobRequest.steps,
|
|
||||||
variation_amount = jobRequest.variation_amount,
|
|
||||||
with_variations = jobRequest.with_variations,
|
|
||||||
facetool_strength = facetool_strength,
|
|
||||||
upscale = upscale,
|
|
||||||
sampler_name = jobRequest.sampler_name,
|
|
||||||
seamless = jobRequest.seamless,
|
|
||||||
hires_fix = jobRequest.hires_fix,
|
|
||||||
embiggen = jobRequest.embiggen,
|
|
||||||
embiggen_tiles = jobRequest.embiggen_tiles,
|
|
||||||
step_callback = lambda sample, step: self.__on_progress(jobRequest, sample, step),
|
|
||||||
image_callback = lambda image, seed, upscaled=False: self.__on_image_result(jobRequest, image, seed, upscaled))
|
|
||||||
|
|
||||||
except CanceledException:
|
|
||||||
self.__signal_service.emit(Signal.job_canceled(jobRequest.id))
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self.__signal_service.emit(Signal.job_done(jobRequest.id))
|
|
||||||
|
|
||||||
# Remove the temp file
|
|
||||||
if (initimgfile is not None):
|
|
||||||
os.remove("./img2img-tmp.png")
|
|
132
server/views.py
132
server/views.py
@ -1,132 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
"""Views module."""
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from queue import Queue
|
|
||||||
from flask import current_app, jsonify, request, Response, send_from_directory, stream_with_context, url_for
|
|
||||||
from flask.views import MethodView
|
|
||||||
from dependency_injector.wiring import inject, Provide
|
|
||||||
|
|
||||||
from server.models import DreamResult, JobRequest
|
|
||||||
from server.services import GeneratorService, ImageStorageService, JobQueueService
|
|
||||||
from server.containers import Container
|
|
||||||
|
|
||||||
class ApiJobs(MethodView):
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]):
|
|
||||||
jobRequest = JobRequest.from_json(request.json)
|
|
||||||
|
|
||||||
print(f">> Request to generate with prompt: {jobRequest.prompt}")
|
|
||||||
|
|
||||||
# Push the request
|
|
||||||
job_queue_service.push(jobRequest)
|
|
||||||
|
|
||||||
return { 'jobId': jobRequest.id }
|
|
||||||
|
|
||||||
|
|
||||||
class WebIndex(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
__file: str = None
|
|
||||||
|
|
||||||
def __init__(self, file):
|
|
||||||
self.__file = file
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
return current_app.send_static_file(self.__file)
|
|
||||||
|
|
||||||
|
|
||||||
class WebConfig(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
|
||||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
|
||||||
config = {
|
|
||||||
'gfpgan_model_exists': gfpgan_model_exists
|
|
||||||
}
|
|
||||||
js = f"let config = {json.dumps(config)};\n"
|
|
||||||
return Response(js, mimetype="application/javascript")
|
|
||||||
|
|
||||||
|
|
||||||
class ApiCancel(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def get(self, generator_service: GeneratorService = Provide[Container.generator_service]):
|
|
||||||
generator_service.cancel()
|
|
||||||
return Response(status=204)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Combine all image storage access
|
|
||||||
class ApiImages(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
__pathRoot = None
|
|
||||||
__storage: ImageStorageService
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
|
|
||||||
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
|
|
||||||
self.__storage = storage
|
|
||||||
|
|
||||||
def get(self, dreamId):
|
|
||||||
name = self.__storage.path(dreamId)
|
|
||||||
fullpath=os.path.join(self.__pathRoot, name)
|
|
||||||
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
|
|
||||||
|
|
||||||
def delete(self, dreamId):
|
|
||||||
result = self.__storage.delete(dreamId)
|
|
||||||
return Response(status=204) if result else Response(status=404)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiImagesMetadata(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
__pathRoot = None
|
|
||||||
__storage: ImageStorageService
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
|
|
||||||
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
|
|
||||||
self.__storage = storage
|
|
||||||
|
|
||||||
def get(self, dreamId):
|
|
||||||
meta = self.__storage.getMetadata(dreamId)
|
|
||||||
j = {} if meta is None else meta.__dict__
|
|
||||||
return j
|
|
||||||
|
|
||||||
|
|
||||||
class ApiIntermediates(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
__pathRoot = None
|
|
||||||
__storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]):
|
|
||||||
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
|
|
||||||
self.__storage = storage
|
|
||||||
|
|
||||||
def get(self, dreamId, step):
|
|
||||||
name = self.__storage.path(dreamId, postfix=f'.{step}')
|
|
||||||
fullpath=os.path.join(self.__pathRoot, name)
|
|
||||||
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
|
|
||||||
|
|
||||||
def delete(self, dreamId):
|
|
||||||
result = self.__storage.delete(dreamId)
|
|
||||||
return Response(status=204) if result else Response(status=404)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiImagesList(MethodView):
|
|
||||||
init_every_request = False
|
|
||||||
__storage: ImageStorageService
|
|
||||||
|
|
||||||
@inject
|
|
||||||
def __init__(self, storage: ImageStorageService = Provide[Container.image_storage_service]):
|
|
||||||
self.__storage = storage
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
page = request.args.get("page", default=0, type=int)
|
|
||||||
perPage = request.args.get("per_page", default=10, type=int)
|
|
||||||
|
|
||||||
result = self.__storage.list_files(page, perPage)
|
|
||||||
return result.__dict__
|
|
Loading…
Reference in New Issue
Block a user