diff --git a/scripts/dream_web.py b/scripts/dream_web.py index 378845e89b..ec909ec5bd 100644 --- a/scripts/dream_web.py +++ b/scripts/dream_web.py @@ -1,5 +1,6 @@ import json import base64 +import mimetypes import os from pytorch_lightning import logging from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer @@ -25,11 +26,15 @@ class DreamServer(BaseHTTPRequestHandler): with open("./static/index.html", "rb") as content: self.wfile.write(content.read()) elif os.path.exists("." + self.path): - self.send_response(200) - self.send_header("Content-type", "image/png") - self.end_headers() - with open("." + self.path, "rb") as content: - self.wfile.write(content.read()) + mime_type = mimetypes.guess_type(self.path)[0] + if mime_type is not None: + self.send_response(200) + self.send_header("Content-type", mime_type) + self.end_headers() + with open("." + self.path, "rb") as content: + self.wfile.write(content.read()) + else: + self.send_response(404) else: self.send_response(404) @@ -85,6 +90,12 @@ class DreamServer(BaseHTTPRequestHandler): self.wfile.write(bytes(json.dumps(result), "utf-8")) if __name__ == "__main__": + # Change working directory to the stable-diffusion directory + os.chdir( + os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')) + ) + + # Start server dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer) print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *") diff --git a/scripts/static/index.css b/scripts/static/index.css new file mode 100644 index 0000000000..ed840a056a --- /dev/null +++ b/scripts/static/index.css @@ -0,0 +1,61 @@ +* { + font-family: 'Arial'; +} +#header { + text-decoration: dotted underline; +} +#search { + margin-top: 20vh; + margin-left: auto; + margin-right: auto; + max-width: 800px; + + text-align: center; +} +fieldset { + border: none; +} +#fieldset-search { + display: flex; +} +#prompt { + flex-grow: 1; + + border-radius: 20px 0px 0px 20px; + padding: 5px 10px 5px 10px; + border: 1px solid black; + border-right: none; + outline: none; +} +#submit { + border-radius: 0px 20px 20px 0px; + padding: 5px 10px 5px 10px; + border: 1px solid black; +} +#results { + text-align: center; + max-width: 1000px; + margin: auto; + padding-top: 10px; +} +img { + cursor: pointer; + height: 30vh; + border-radius: 5px; + margin: 10px; +} +#fieldset-config { + line-height:2em; +} +input[type="number"] { + width: 60px; +} +#seed { + width: 150px; +} +hr { + width: 200px; +} +label { + white-space: nowrap; +} diff --git a/scripts/static/index.js b/scripts/static/index.js new file mode 100644 index 0000000000..3b99deecf4 --- /dev/null +++ b/scripts/static/index.js @@ -0,0 +1,101 @@ +function toBase64(file) { + return new Promise((resolve, reject) => { + const r = new FileReader(); + r.readAsDataURL(file); + r.onload = () => resolve(r.result); + r.onerror = (error) => reject(error); + }); +} + +function appendOutput(output) { + let outputNode = document.createElement("img"); + outputNode.src = output[0]; + + let outputConfig = output[2]; + let altText = output[1].toString() + " | " + outputConfig.prompt; + outputNode.alt = altText; + outputNode.title = altText; + + // Reload image config + outputNode.addEventListener('click', () => { + let form = document.querySelector("#generate-form"); + for (const [k, v] of new FormData(form)) { + form.querySelector(`*[name=${k}]`).value = outputConfig[k]; + } + document.querySelector("#seed").value = output[1]; + + saveFields(document.querySelector("#generate-form")); + }); + + document.querySelector("#results").prepend(outputNode); +} + +function appendOutputs(outputs) { + for (const output of outputs) { + appendOutput(output); + } +} + +function saveFields(form) { + for (const [k, v] of new FormData(form)) { + if (typeof v !== 'object') { // Don't save 'file' type + localStorage.setItem(k, v); + } + } +} +function loadFields(form) { + for (const [k, v] of new FormData(form)) { + const item = localStorage.getItem(k); + if (item != null) { + form.querySelector(`*[name=${k}]`).value = item; + } + } +} + +async function generateSubmit(form) { + const prompt = document.querySelector("#prompt").value; + + // Convert file data to base64 + let formData = Object.fromEntries(new FormData(form)); + formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; + + // Post as JSON + fetch(form.action, { + method: form.method, + body: JSON.stringify(formData), + }).then(async (result) => { + let data = await result.json(); + + // Re-enable form, remove no-results-message + form.querySelector('fieldset').removeAttribute('disabled'); + document.querySelector("#prompt").value = prompt; + + if (data.outputs.length != 0) { + document.querySelector("#no-results-message")?.remove(); + appendOutputs(data.outputs); + } else { + alert("Error occurred while generating."); + } + }); + + // Disable form while generating + form.querySelector('fieldset').setAttribute('disabled',''); + document.querySelector("#prompt").value = `Generating: "${prompt}"`; +} + +window.onload = () => { + document.querySelector("#generate-form").addEventListener('submit', (e) => { + e.preventDefault(); + const form = e.target; + + generateSubmit(form); + }); + document.querySelector("#generate-form").addEventListener('change', (e) => { + saveFields(e.target.form); + }); + document.querySelector("#reset").addEventListener('click', (e) => { + document.querySelector("#seed").value = -1; + saveFields(e.target.form); + }); + loadFields(document.querySelector("#generate-form")); +}; diff --git a/static/index.html b/static/index.html index a96405aafb..4e7af0f771 100644 --- a/static/index.html +++ b/static/index.html @@ -2,170 +2,23 @@
No results...