Add img2img support, fix naming conventions

This commit is contained in:
tesseractcat 2022-08-24 23:03:02 -04:00
parent 269fcf92d9
commit ab131cb55e
2 changed files with 105 additions and 57 deletions

View File

@ -1,4 +1,5 @@
import json import json
import base64
import os import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
@ -28,6 +29,7 @@ class DreamServer(BaseHTTPRequestHandler):
content_length = int(self.headers['Content-Length']) content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length)) post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt'] prompt = post_data['prompt']
initimg = post_data['initimg']
batch = int(post_data['batch']) batch = int(post_data['batch'])
steps = int(post_data['steps']) steps = int(post_data['steps'])
width = int(post_data['width']) width = int(post_data['width'])
@ -35,16 +37,37 @@ class DreamServer(BaseHTTPRequestHandler):
cfgscale = float(post_data['cfgscale']) cfgscale = float(post_data['cfgscale'])
seed = None if int(post_data['seed']) == -1 else int(post_data['seed']) seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
print(f"Request to generate with data: {post_data}") print(f"Request to generate with prompt: {prompt}")
outputs = []
if initimg is None:
# Run txt2img
outputs = model.txt2img(prompt, outputs = model.txt2img(prompt,
batch_size = batch, batch_size = batch,
cfg_scale = cfgscale, cfg_scale = cfgscale,
width = width, width = width,
height = height, height = height,
seed = seed, seed = seed,
steps = steps); steps = steps)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
# Run img2img
outputs = model.img2img(prompt,
init_img = "./img2img-tmp.png",
batch_size = batch,
cfg_scale = cfgscale,
seed = seed,
steps = steps)
# Remove the temp file
os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}") print(f"Prompt generated with output: {outputs}")
post_data['initimg'] = '' # Don't send init image back
outputs = [x + [post_data] for x in outputs] # Append config to each output outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs} result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8")) self.wfile.write(bytes(json.dumps(result), "utf-8"))

View File

@ -40,6 +40,9 @@
border-radius: 5px; border-radius: 5px;
margin: 10px; margin: 10px;
} }
#generate-config {
line-height:2em;
}
input[type="number"] { input[type="number"] {
width: 60px; width: 60px;
} }
@ -48,41 +51,52 @@
} }
</style> </style>
<script> <script>
function append_output(output) { function toBase64(file) {
let output_node = document.createElement("img"); return new Promise((resolve, reject) => {
output_node.src = output[0]; const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
let output_config = output[2]; function appendOutput(output) {
let alt_text = output[1].toString() + " | " + output_config.prompt; let outputNode = document.createElement("img");
output_node.alt = alt_text; outputNode.src = output[0];
output_node.title = alt_text;
let outputConfig = output[2];
let altText = output[1].toString() + " | " + outputConfig.prompt;
outputNode.alt = altText;
outputNode.title = altText;
// Reload image config // Reload image config
output_node.addEventListener('click', () => { outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate_form"); let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = output_config[k]; form.querySelector(`*[name=${k}]`).value = outputConfig[k];
} }
document.querySelector("#seed").value = output[1]; document.querySelector("#seed").value = output[1];
save_fields(document.querySelector("#generate_form")); saveFields(document.querySelector("#generate-form"));
}); });
document.querySelector("#results").prepend(output_node); document.querySelector("#results").prepend(outputNode);
} }
function append_outputs(outputs) { function appendOutputs(outputs) {
for (const output of outputs) { for (const output of outputs) {
append_output(output); appendOutput(output);
} }
} }
function save_fields(form) { function saveFields(form) {
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v); localStorage.setItem(k, v);
} }
} }
function load_fields(form) { }
function loadFields(form) {
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k); const item = localStorage.getItem(k);
if (item != null) { if (item != null) {
@ -91,43 +105,52 @@
} }
} }
window.onload = () => { async function generateSubmit(form) {
document.querySelector("#generate_form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
const prompt = document.querySelector("#prompt").value; const prompt = document.querySelector("#prompt").value;
// Convert file data to base64
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Post as JSON // Post as JSON
fetch(form.action, { fetch(form.action, {
method: form.method, method: form.method,
body: JSON.stringify(Object.fromEntries(new FormData(form))), body: JSON.stringify(formData),
}).then((result) => { }).then(async (result) => {
result.json().then((data) => { let data = await result.json();
// Re-enable form, remove no-results-message // Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled'); form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt; document.querySelector("#prompt").value = prompt;
if (data.outputs.length != 0) { if (data.outputs.length != 0) {
document.querySelector("#no_results_message")?.remove(); document.querySelector("#no-results-message")?.remove();
append_outputs(data.outputs); appendOutputs(data.outputs);
} else { } else {
alert("Error occurred while generating."); alert("Error occurred while generating.");
} }
}); });
});
// Disable form // Disable form while generating
form.querySelector('fieldset').setAttribute('disabled',''); form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`; document.querySelector("#prompt").value = `Generating: "${prompt}"`;
}
window.onload = () => {
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
}); });
document.querySelector("#generate_form").addEventListener('change', (e) => { document.querySelector("#generate-form").addEventListener('change', (e) => {
save_fields(e.target.form); saveFields(e.target.form);
}); });
document.querySelector("#reset").addEventListener('click', (e) => { document.querySelector("#reset").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1; document.querySelector("#seed").value = -1;
save_fields(e.target.form); saveFields(e.target.form);
}); });
load_fields(document.querySelector("#generate_form")); loadFields(document.querySelector("#generate-form"));
}; };
</script> </script>
</head> </head>
@ -135,12 +158,12 @@
<div id="search"> <div id="search">
<h2 id="header">Stable Diffusion</h2> <h2 id="header">Stable Diffusion</h2>
<form id="generate_form" method="post" action="#"> <form id="generate-form" method="post" action="#">
<fieldset> <fieldset>
<input type="text" id="prompt" name="prompt"> <input type="text" id="prompt" name="prompt">
<input type="submit" id="submit" value="Generate"> <input type="submit" id="submit" value="Generate">
</fieldset> </fieldset>
<fieldset> <fieldset id="generate-config">
<label for="batch">Batch Size:</label> <label for="batch">Batch Size:</label>
<input value="1" type="number" id="batch" name="batch"> <input value="1" type="number" id="batch" name="batch">
<label for="steps">Steps:</label> <label for="steps">Steps:</label>
@ -152,7 +175,9 @@
<input value="512" type="number" id="width" name="width"> <input value="512" type="number" id="width" name="width">
<label title="Set to multiple of 64" for="height">Height:</label> <label title="Set to multiple of 64" for="height">Height:</label>
<input value="512" type="number" id="height" name="height"> <input value="512" type="number" id="height" name="height">
<span>&bull;</span> <br>
<label title="Upload an image to use img2img" for="initimg">Img2Img Init:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<label title="Set to -1 for random seed" for="seed">Seed:</label> <label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed"> <input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset">&olarr;</button> <button type="button" id="reset">&olarr;</button>
@ -161,7 +186,7 @@
</div> </div>
<hr style="width: 200px"> <hr style="width: 200px">
<div id="results"> <div id="results">
<div id="no_results_message"> <div id="no-results-message">
<i><p>No results...</p></i> <i><p>No results...</p></i>
</div> </div>
</div> </div>