mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added customized patches and updated the README
This commit is contained in:
parent
d39f5b51a8
commit
d6124c44a3
91
README.md
91
README.md
@ -1,4 +1,95 @@
|
|||||||
# Stable Diffusion
|
# Stable Diffusion
|
||||||
|
|
||||||
|
This is a fork of CompVis/stable-diffusion, the wonderful open source
|
||||||
|
text-to-image generator.
|
||||||
|
|
||||||
|
The original has been modified in several minor ways:
|
||||||
|
|
||||||
|
## Simplified API for text to image generation
|
||||||
|
|
||||||
|
There is now a simplified API for text to image generation, which
|
||||||
|
lets you create images from a prompt in just three lines of code:
|
||||||
|
|
||||||
|
~~~~
|
||||||
|
from ldm.simplet2i import T2I
|
||||||
|
model = T2I()
|
||||||
|
model.text2image("a unicorn in manhattan")
|
||||||
|
~~~~
|
||||||
|
|
||||||
|
Please see ldm/simplet2i.py for more information.
|
||||||
|
|
||||||
|
## Interactive command-line interface similar to the Discord bot
|
||||||
|
|
||||||
|
There is now a command-line script, located in scripts/dream.py, which
|
||||||
|
provides an interactive interface to image generation similar to
|
||||||
|
the "dream mothership" bot that Stable AI provided on its Discord
|
||||||
|
server. The advantage of this is that the lengthy model
|
||||||
|
initialization only happens once. After that image generation is
|
||||||
|
fast.
|
||||||
|
|
||||||
|
Note that this has only been tested in the Linux environment!
|
||||||
|
|
||||||
|
(ldm) ~/stable-diffusion$ ./scripts/dream.py
|
||||||
|
* Initializing, be patient...
|
||||||
|
|
||||||
|
Loading model from models/ldm/text2img-large/model.ckpt
|
||||||
|
LatentDiffusion: Running in eps-prediction mode
|
||||||
|
DiffusionWrapper has 872.30 M params.
|
||||||
|
making attention of type 'vanilla' with 512 in_channels
|
||||||
|
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
making attention of type 'vanilla' with 512 in_channels
|
||||||
|
Loading Bert tokenizer from "models/bert"
|
||||||
|
setting sampler to plms
|
||||||
|
|
||||||
|
* Initialization done! Awaiting your command...
|
||||||
|
dream> ashley judd riding a camel -n2
|
||||||
|
Outputs:
|
||||||
|
outputs/txt2img-samples/00009.png: "ashley judd riding a camel" -n2 -S 416354203
|
||||||
|
outputs/txt2img-samples/00010.png: "ashley judd riding a camel" -n2 -S 1362479620
|
||||||
|
|
||||||
|
Command-line arguments ("./scripts/dream.py -h") allow you to change
|
||||||
|
various defaults, and select between the mature stable-diffusion
|
||||||
|
weights (512x512) and the older (256x256) latent diffusion weights
|
||||||
|
(laion400m).
|
||||||
|
|
||||||
|
## No need for internet connectivity when loading the model
|
||||||
|
|
||||||
|
My development machine is a GPU node in a high-performance compute
|
||||||
|
cluster which has no connection to the internet. During model
|
||||||
|
initialization, stable-diffusion tries to download the Bert tokenizer
|
||||||
|
model from huggingface.co. This obviously didn't work for me.
|
||||||
|
|
||||||
|
Rather than set up a hugging face local hub, I found the most
|
||||||
|
expedient thing to do was to download the Bert tokenizer in advance,
|
||||||
|
and patch stable-diffusion to read it from the local disk. The steps
|
||||||
|
to do this are:
|
||||||
|
|
||||||
|
(ldm) ~/stable-diffusion$ mkdir ./models/bert
|
||||||
|
> python3
|
||||||
|
>>> from transformers import BertTokenizerFast
|
||||||
|
>>> model = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||||
|
>>> model.save_pretrained("./models/bert")
|
||||||
|
|
||||||
|
(Make sure you are in the stable-diffusion directory when you do
|
||||||
|
this!)
|
||||||
|
|
||||||
|
If you don't like this change, just copy over the file
|
||||||
|
ldm/modules/encoders/modules.py from the CompVis/stable-diffusion
|
||||||
|
repository.
|
||||||
|
|
||||||
|
## Minor fixes
|
||||||
|
|
||||||
|
I added the requirement for torchmetrics to environment.yaml.
|
||||||
|
|
||||||
|
## Installation and support
|
||||||
|
|
||||||
|
Follow the directions from the original README, which starts below, to
|
||||||
|
configure the environment and install requirements. For support,
|
||||||
|
please use this repository's GitHub Issues tracking service.
|
||||||
|
|
||||||
|
Author: Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
|
|
||||||
|
# Original README from CompViz/stable-diffusion
|
||||||
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
|
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
|
||||||
|
|
||||||
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
|
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
|
||||||
|
@ -135,7 +135,7 @@ class DDIMSampler(object):
|
|||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True)
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
@ -238,4 +238,4 @@ class DDIMSampler(object):
|
|||||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning)
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
return x_dec
|
return x_dec
|
||||||
|
@ -255,7 +255,7 @@ class DDPM(pl.LightningModule):
|
|||||||
b = shape[0]
|
b = shape[0]
|
||||||
img = torch.randn(shape, device=device)
|
img = torch.randn(shape, device=device)
|
||||||
intermediates = [img]
|
intermediates = [img]
|
||||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps, dynamic_ncols=True):
|
||||||
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
||||||
clip_denoised=self.clip_denoised)
|
clip_denoised=self.clip_denoised)
|
||||||
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
||||||
|
@ -92,7 +92,7 @@ class PLMSSampler(object):
|
|||||||
# sampling
|
# sampling
|
||||||
C, H, W = shape
|
C, H, W = shape
|
||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
print(f'Data shape for PLMS sampling is {size}')
|
# print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
@ -134,9 +134,9 @@ class PLMSSampler(object):
|
|||||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
# print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True)
|
||||||
old_eps = []
|
old_eps = []
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
|
@ -55,7 +55,10 @@ class BERTTokenizer(AbstractEncoder):
|
|||||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
fn = 'models/bert'
|
||||||
|
print(f'Loading Bert tokenizer from "{fn}"')
|
||||||
|
# self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||||
|
self.tokenizer = BertTokenizerFast.from_pretrained(fn,local_files_only=True)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.vq_interface = vq_interface
|
self.vq_interface = vq_interface
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@ -231,4 +234,5 @@ class FrozenClipImageEmbedder(nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from ldm.util import count_params
|
from ldm.util import count_params
|
||||||
model = FrozenCLIPEmbedder()
|
model = FrozenCLIPEmbedder()
|
||||||
count_params(model, verbose=True)
|
count_params(model, verbose=True)
|
||||||
|
|
||||||
|
258
ldm/simplet2i.py
Normal file
258
ldm/simplet2i.py
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
|
||||||
|
from ldm.simplet2i import T2I
|
||||||
|
# Create an object with default values
|
||||||
|
t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
||||||
|
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
|
||||||
|
batch = <integer> // 1
|
||||||
|
steps = <integer> // 50
|
||||||
|
seed = <integer> // current system time
|
||||||
|
sampler = ['ddim','plms'] // ddim
|
||||||
|
grid = <boolean> // false
|
||||||
|
width = <integer> // image width, multiple of 64 (512)
|
||||||
|
height = <integer> // image height, multiple of 64 (512)
|
||||||
|
cfg_scale = <float> // unconditional guidance scale (7.5)
|
||||||
|
fixed_code = <boolean> // False
|
||||||
|
)
|
||||||
|
# do the slow model initialization
|
||||||
|
t2i.load_model()
|
||||||
|
|
||||||
|
# Do the fast inference & image generation. Any options passed here
|
||||||
|
# override the default values assigned during class initialization
|
||||||
|
# Will call load_model() if the model was not previously loaded.
|
||||||
|
t2i.txt2img(prompt = <string> // required
|
||||||
|
// the remaining option arguments override constructur value when present
|
||||||
|
outdir = <path>
|
||||||
|
iterations = <integer>
|
||||||
|
batch = <integer>
|
||||||
|
steps = <integer>
|
||||||
|
seed = <integer>
|
||||||
|
sampler = ['ddim','plms']
|
||||||
|
grid = <boolean>
|
||||||
|
width = <integer>
|
||||||
|
height = <integer>
|
||||||
|
cfg_scale = <float>
|
||||||
|
) -> boolean
|
||||||
|
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
from time import time
|
||||||
|
from math import sqrt
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
class T2I:
|
||||||
|
"""T2I class
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
outdir
|
||||||
|
model
|
||||||
|
config
|
||||||
|
iterations
|
||||||
|
batch
|
||||||
|
steps
|
||||||
|
seed
|
||||||
|
sampler
|
||||||
|
grid
|
||||||
|
width
|
||||||
|
height
|
||||||
|
cfg_scale
|
||||||
|
fixed_code
|
||||||
|
latent_channels
|
||||||
|
downsampling_factor
|
||||||
|
precision
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
outdir="outputs/txt2img-samples",
|
||||||
|
batch=1,
|
||||||
|
iterations = 1,
|
||||||
|
width=256, # change to 512 for stable diffusion
|
||||||
|
height=256, # change to 512 for stable diffusion
|
||||||
|
grid=False,
|
||||||
|
steps=50,
|
||||||
|
seed=None,
|
||||||
|
cfg_scale=7.5,
|
||||||
|
weights="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||||
|
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
|
||||||
|
sampler="plms",
|
||||||
|
latent_channels=4,
|
||||||
|
downsampling_factor=8,
|
||||||
|
ddim_eta=0.0, # deterministic
|
||||||
|
fixed_code=False,
|
||||||
|
precision='autocast'
|
||||||
|
):
|
||||||
|
self.outdir = outdir
|
||||||
|
self.batch = batch
|
||||||
|
self.iterations = iterations
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.grid = grid
|
||||||
|
self.steps = steps
|
||||||
|
self.cfg_scale = cfg_scale
|
||||||
|
self.weights = weights
|
||||||
|
self.config = config
|
||||||
|
self.sampler_name = sampler
|
||||||
|
self.fixed_code = fixed_code
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
self.downsampling_factor = downsampling_factor
|
||||||
|
self.ddim_eta = ddim_eta
|
||||||
|
self.precision = precision
|
||||||
|
self.model = None # empty for now
|
||||||
|
self.sampler = None
|
||||||
|
if seed is None:
|
||||||
|
self.seed = self._new_seed()
|
||||||
|
else:
|
||||||
|
self.seed = seed
|
||||||
|
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
|
||||||
|
steps=None,seed=None,grid=None,width=None,height=None,
|
||||||
|
cfg_scale=None,ddim_eta=None):
|
||||||
|
""" generate an image from the prompt, writing iteration images into the outdir """
|
||||||
|
outdir = outdir or self.outdir
|
||||||
|
steps = steps or self.steps
|
||||||
|
seed = seed or self.seed
|
||||||
|
width = width or self.width
|
||||||
|
height = height or self.height
|
||||||
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
|
batch = batch or self.batch
|
||||||
|
iterations = iterations or self.iterations
|
||||||
|
if batch > 1:
|
||||||
|
iterations = 1
|
||||||
|
|
||||||
|
model = self.load_model() # will instantiate the model or return it from cache
|
||||||
|
|
||||||
|
if (grid is None):
|
||||||
|
grid = self.grid
|
||||||
|
data = [batch * [prompt]]
|
||||||
|
|
||||||
|
# make directories and establish names for the output files
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(outdir))-1
|
||||||
|
|
||||||
|
start_code = None
|
||||||
|
if self.fixed_code:
|
||||||
|
start_code = torch.randn([batch,
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
precision_scope = autocast if self.precision=="autocast" else nullcontext
|
||||||
|
sampler = self.sampler
|
||||||
|
images = list()
|
||||||
|
seeds = list()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
all_samples = list()
|
||||||
|
for n in trange(iterations, desc="Sampling"):
|
||||||
|
seed_everything(seed)
|
||||||
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
|
uc = None
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
uc = model.get_learned_conditioning(batch * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
||||||
|
samples_ddim, _ = sampler.sample(S=steps,
|
||||||
|
conditioning=c,
|
||||||
|
batch_size=batch,
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=cfg_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=ddim_eta,
|
||||||
|
x_T=start_code)
|
||||||
|
|
||||||
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
for x_sample in x_samples_ddim:
|
||||||
|
if grid:
|
||||||
|
all_samples.append(x_samples_ddim)
|
||||||
|
seeds.append(seed)
|
||||||
|
else:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
filename = os.path.join(outdir, f"{base_count:05}.png")
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
|
||||||
|
images.append([filename,seed])
|
||||||
|
base_count += 1
|
||||||
|
seed = self._new_seed()
|
||||||
|
|
||||||
|
if grid:
|
||||||
|
n_rows = int(sqrt(batch * iterations))
|
||||||
|
# save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
filename = os.path.join(outdir, f"{base_count:05}.png")
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(filename)
|
||||||
|
for s in seeds:
|
||||||
|
images.append([filename,s])
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def _new_seed(self):
|
||||||
|
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||||
|
return self.seed
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
""" Load and initialize the model from configuration variables passed at object creation time """
|
||||||
|
if self.model is None:
|
||||||
|
seed_everything(self.seed)
|
||||||
|
try:
|
||||||
|
config = OmegaConf.load(self.config)
|
||||||
|
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = self._load_model_from_config(config,self.weights)
|
||||||
|
self.model = model.to(self.device)
|
||||||
|
except AttributeError:
|
||||||
|
raise SystemExit
|
||||||
|
|
||||||
|
if self.sampler_name=='plms':
|
||||||
|
print("setting sampler to plms")
|
||||||
|
self.sampler = PLMSSampler(self.model)
|
||||||
|
elif self.sampler_name == 'ddim':
|
||||||
|
print("setting sampler to ddim")
|
||||||
|
self.sampler = DDIMSampler(self.model)
|
||||||
|
else:
|
||||||
|
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
|
||||||
|
self.sampler = PLMSSampler(self.model)
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def _load_model_from_config(self, config, ckpt):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
@ -12,8 +12,6 @@ from queue import Queue
|
|||||||
|
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
|
||||||
def log_txt_as_img(wh, xc, size=10):
|
def log_txt_as_img(wh, xc, size=10):
|
||||||
# wh a tuple of (width, height)
|
# wh a tuple of (width, height)
|
||||||
# xc a list of captions to plot
|
# xc a list of captions to plot
|
||||||
|
144
scripts/dream.py
Executable file
144
scripts/dream.py
Executable file
@ -0,0 +1,144 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import readline
|
||||||
|
import argparse
|
||||||
|
import shlex
|
||||||
|
import atexit
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
def main():
|
||||||
|
arg_parser = create_argv_parser()
|
||||||
|
opt = arg_parser.parse_args()
|
||||||
|
if opt.laion400m:
|
||||||
|
# defaults suitable to the older latent diffusion weights
|
||||||
|
width = 256
|
||||||
|
height = 256
|
||||||
|
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
|
||||||
|
weights = "models/ldm/text2img-large/model.ckpt"
|
||||||
|
else:
|
||||||
|
# some defaults suitable for stable diffusion weights
|
||||||
|
width = 512
|
||||||
|
height = 512
|
||||||
|
config = "configs/stable-diffusion/v1-inference.yaml"
|
||||||
|
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||||
|
|
||||||
|
# command line history will be stored in a file called "~/.dream_history"
|
||||||
|
load_history()
|
||||||
|
|
||||||
|
print("* Initializing, be patient...\n")
|
||||||
|
from pytorch_lightning import logging
|
||||||
|
from ldm.simplet2i import T2I
|
||||||
|
|
||||||
|
# creating a simple text2image object with a handful of
|
||||||
|
# defaults passed on the command line.
|
||||||
|
# additional parameters will be added (or overriden) during
|
||||||
|
# the user input loop
|
||||||
|
t2i = T2I(width=width,
|
||||||
|
height=height,
|
||||||
|
batch=opt.batch,
|
||||||
|
outdir=opt.outdir,
|
||||||
|
sampler=opt.sampler,
|
||||||
|
weights=weights,
|
||||||
|
config=config)
|
||||||
|
|
||||||
|
# gets rid of annoying messages about random seed
|
||||||
|
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# preload the model
|
||||||
|
t2i.load_model()
|
||||||
|
print("\n* Initialization done! Awaiting your command...")
|
||||||
|
|
||||||
|
log_path = path.join(opt.outdir,"dream_log.txt")
|
||||||
|
with open(log_path,'a') as log:
|
||||||
|
cmd_parser = create_cmd_parser()
|
||||||
|
main_loop(t2i,cmd_parser,log)
|
||||||
|
log.close()
|
||||||
|
|
||||||
|
def main_loop(t2i,parser,log):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
command = input("dream> ")
|
||||||
|
except EOFError:
|
||||||
|
print("goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
elements = shlex.split(command)
|
||||||
|
switches = ['']
|
||||||
|
switches_started = False
|
||||||
|
|
||||||
|
for el in elements:
|
||||||
|
if el[0]=='-' and not switches_started:
|
||||||
|
switches_started = True
|
||||||
|
if switches_started:
|
||||||
|
switches.append(el)
|
||||||
|
else:
|
||||||
|
switches[0] += el
|
||||||
|
switches[0] += ' '
|
||||||
|
switches[0] = switches[0][:len(switches[0])-1]
|
||||||
|
try:
|
||||||
|
opt = parser.parse_args(switches)
|
||||||
|
except SystemExit:
|
||||||
|
parser.print_help()
|
||||||
|
pass
|
||||||
|
results = t2i.txt2img(**vars(opt))
|
||||||
|
print("Outputs:")
|
||||||
|
for r in results:
|
||||||
|
log_message = " ".join([' ',str(r[0])+':',
|
||||||
|
f'"{switches[0]}"',
|
||||||
|
*switches[1:],f'-S {r[1]}'])
|
||||||
|
print(log_message)
|
||||||
|
log.write(log_message+"\n")
|
||||||
|
log.flush()
|
||||||
|
|
||||||
|
def create_argv_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Parse script's command line args")
|
||||||
|
parser.add_argument("--laion400m",
|
||||||
|
"--latent_diffusion",
|
||||||
|
"-l",
|
||||||
|
dest='laion400m',
|
||||||
|
action='store_true',
|
||||||
|
help="fallback to the latent diffusion (LAION4400M) weights and config")
|
||||||
|
parser.add_argument('-n','--iterations',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="number of images to produce per sampling (overrides -n<iterations>, faster but doesn't produce individual seeds)")
|
||||||
|
parser.add_argument('-b','--batch',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="number of images to produce per sampling (currently broken")
|
||||||
|
parser.add_argument('--sampler',
|
||||||
|
choices=['plms','ddim'],
|
||||||
|
default='plms',
|
||||||
|
help="which sampler to use")
|
||||||
|
parser.add_argument('-o',
|
||||||
|
'--outdir',
|
||||||
|
type=str,
|
||||||
|
default="outputs/txt2img-samples",
|
||||||
|
help="directory in which to place generated images and a log of prompts and seeds")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def create_cmd_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Parse terminal input in a discord 'dreambot' fashion")
|
||||||
|
parser.add_argument('prompt')
|
||||||
|
parser.add_argument('-s','--steps',type=int,help="number of steps")
|
||||||
|
parser.add_argument('-S','--seed',type=int,help="image seed")
|
||||||
|
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
|
||||||
|
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
||||||
|
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
||||||
|
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
||||||
|
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
|
||||||
|
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def load_history():
|
||||||
|
histfile = path.join(path.expanduser('~'),".dream_history")
|
||||||
|
try:
|
||||||
|
readline.read_history_file(histfile)
|
||||||
|
readline.set_history_length(1000)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
atexit.register(readline.write_history_file,histfile)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user