mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactor generate.py and dream.py (#534)
* revert inadvertent change of conda env name (#528) * Refactor generate.py and dream.py * config file path (models.yaml) is parsed inside Generate() to simplify API * Better handling of keyboard interrupts in file loading mode vs interactive * Removed oodles of unused variables. * move nonfunctional inpainting out of the scripts directory * fix ugly ddim tqdm formatting
This commit is contained in:
parent
d15c75ecae
commit
e6179af46a
204
ldm/generate.py
204
ldm/generate.py
@ -17,7 +17,7 @@ import transformers
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
@ -35,7 +35,7 @@ Example Usage:
|
||||
from ldm.generate import Generate
|
||||
|
||||
# Create an object with default values
|
||||
gr = Generate()
|
||||
gr = Generate('stable-diffusion-1.4')
|
||||
|
||||
# do the slow model initialization
|
||||
gr.load_model()
|
||||
@ -79,16 +79,17 @@ still work.
|
||||
|
||||
The full list of arguments to Generate() are:
|
||||
gr = Generate(
|
||||
# these values are set once and shouldn't be changed
|
||||
conf = path to configuration file ('configs/models.yaml')
|
||||
model = symbolic name of the model in the configuration file
|
||||
full_precision = False
|
||||
|
||||
# this value is sticky and maintained between generation calls
|
||||
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
|
||||
# these are deprecated - use conf and model instead
|
||||
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
|
||||
config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
|
||||
iterations = <integer> // how many times to run the sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
cfg_scale = <float> // condition-free guidance scale (7.5)
|
||||
config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
|
||||
)
|
||||
|
||||
"""
|
||||
@ -101,66 +102,62 @@ class Generate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterations = 1,
|
||||
steps = 50,
|
||||
cfg_scale = 7.5,
|
||||
weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
|
||||
config = 'configs/stable-diffusion/v1-inference.yaml',
|
||||
grid = False,
|
||||
width = 512,
|
||||
height = 512,
|
||||
model = 'stable-diffusion-1.4',
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
full_precision = False,
|
||||
strength = 0.75, # default in scripts/img2img.py
|
||||
seamless = False,
|
||||
embedding_path = None,
|
||||
device_type = 'cuda',
|
||||
ignore_ctrl_c = False,
|
||||
# these are deprecated; if present they override values in the conf file
|
||||
weights = None,
|
||||
config = None,
|
||||
):
|
||||
self.iterations = iterations
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.steps = steps
|
||||
self.cfg_scale = cfg_scale
|
||||
self.weights = weights
|
||||
self.config = config
|
||||
self.sampler_name = sampler_name
|
||||
self.grid = grid
|
||||
self.ddim_eta = ddim_eta
|
||||
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
||||
self.strength = strength
|
||||
self.seamless = seamless
|
||||
self.embedding_path = embedding_path
|
||||
self.device_type = device_type
|
||||
self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
|
||||
self.model = None # empty for now
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.generators = {}
|
||||
self.base_generator = None
|
||||
self.seed = None
|
||||
models = OmegaConf.load(conf)
|
||||
mconfig = models[model]
|
||||
self.weights = mconfig.weights if weights is None else weights
|
||||
self.config = mconfig.config if config is None else config
|
||||
self.height = mconfig.height
|
||||
self.width = mconfig.width
|
||||
self.iterations = 1
|
||||
self.steps = 50
|
||||
self.cfg_scale = 7.5
|
||||
self.sampler_name = sampler_name
|
||||
self.ddim_eta = 0.0 # same seed always produces same image
|
||||
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
||||
self.strength = 0.75
|
||||
self.seamless = False
|
||||
self.embedding_path = embedding_path
|
||||
self.model = None # empty for now
|
||||
self.sampler = None
|
||||
self.device = None
|
||||
self.session_peakmem = None
|
||||
self.generators = {}
|
||||
self.base_generator = None
|
||||
self.seed = None
|
||||
|
||||
if device_type == 'cuda' and not torch.cuda.is_available():
|
||||
device_type = choose_torch_device()
|
||||
print(">> cuda not available, using device", device_type)
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
# it wasn't actually doing anything. This logic could be reinstated.
|
||||
device_type = choose_torch_device()
|
||||
self.device = torch.device(device_type)
|
||||
|
||||
# for VRAM usage statistics
|
||||
device_type = choose_torch_device()
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# gets rid of annoying messages about random seed
|
||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
|
||||
Optional named arguments are the same as those passed to Generate and prompt2image()
|
||||
"""
|
||||
results = self.prompt2image(prompt, **kwargs)
|
||||
results = self.prompt2image(prompt, **kwargs)
|
||||
pngwriter = PngWriter(outdir)
|
||||
prefix = pngwriter.unique_prefix()
|
||||
outputs = []
|
||||
prefix = pngwriter.unique_prefix()
|
||||
outputs = []
|
||||
for image, seed in results:
|
||||
name = f'{prefix}.{seed}.png'
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
@ -183,33 +180,35 @@ class Generate:
|
||||
self,
|
||||
# these are common
|
||||
prompt,
|
||||
iterations = None,
|
||||
steps = None,
|
||||
seed = None,
|
||||
cfg_scale = None,
|
||||
ddim_eta = None,
|
||||
skip_normalize = False,
|
||||
image_callback = None,
|
||||
step_callback = None,
|
||||
width = None,
|
||||
height = None,
|
||||
sampler_name = None,
|
||||
seamless = False,
|
||||
log_tokenization= False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
iterations = None,
|
||||
steps = None,
|
||||
seed = None,
|
||||
cfg_scale = None,
|
||||
ddim_eta = None,
|
||||
skip_normalize = False,
|
||||
image_callback = None,
|
||||
step_callback = None,
|
||||
width = None,
|
||||
height = None,
|
||||
sampler_name = None,
|
||||
seamless = False,
|
||||
log_tokenization = False,
|
||||
with_variations = None,
|
||||
variation_amount = 0.0,
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
fit = False,
|
||||
strength = None,
|
||||
# these are specific to embiggen (which also relies on img2img args)
|
||||
embiggen = None,
|
||||
embiggen_tiles = None,
|
||||
# these are specific to GFPGAN/ESRGAN
|
||||
gfpgan_strength= 0,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
gfpgan_strength = 0,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
@ -262,10 +261,9 @@ class Generate:
|
||||
self.log_tokenization = log_tokenization
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
model = (
|
||||
self.load_model()
|
||||
) # will instantiate the model or return it from cache
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
model = self.load_model()
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||
@ -281,7 +279,6 @@ class Generate:
|
||||
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
|
||||
), 'Embiggen requires an init/input image to be specified'
|
||||
|
||||
# check this logic - doesn't look right
|
||||
if len(with_variations) > 0 or variation_amount > 1.0:
|
||||
assert seed is not None,\
|
||||
'seed must be specified when using with_variations'
|
||||
@ -298,7 +295,7 @@ class Generate:
|
||||
self._set_sampler()
|
||||
|
||||
tic = time.time()
|
||||
if torch.cuda.is_available():
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
results = list()
|
||||
@ -307,9 +304,9 @@ class Generate:
|
||||
|
||||
try:
|
||||
uc, c = get_uc_and_c(
|
||||
prompt, model=self.model,
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
log_tokens=self.log_tokenization
|
||||
log_tokens =self.log_tokenization
|
||||
)
|
||||
|
||||
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
|
||||
@ -352,27 +349,25 @@ class Generate:
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('*interrupted*')
|
||||
if not self.ignore_ctrl_c:
|
||||
raise KeyboardInterrupt
|
||||
print(
|
||||
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
except KeyboardInterrupt:
|
||||
if catch_interrupts:
|
||||
print('**Interrupted** Partial results will be returned.')
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
print(
|
||||
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
if torch.cuda.is_available() and self.device.type == 'cuda':
|
||||
if self._has_cuda():
|
||||
print(
|
||||
f'>> Max VRAM used for this generation:',
|
||||
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
'Current VRAM utilization:'
|
||||
'Current VRAM utilization:',
|
||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
@ -439,8 +434,7 @@ class Generate:
|
||||
if self.model is None:
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
try:
|
||||
config = OmegaConf.load(self.config)
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
model = self._load_model_from_config(self.config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
self.embedding_path, self.full_precision
|
||||
@ -541,8 +535,11 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
print(f'>> Loading model from {ckpt}')
|
||||
# Be warned: config is the path to the model config file, not the dream conf file!
|
||||
# Also note that we can get config and weights from self, so why do we need to
|
||||
# pass them as args?
|
||||
def _load_model_from_config(self, config, weights):
|
||||
print(f'>> Loading model from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
device_type = choose_torch_device()
|
||||
@ -551,10 +548,11 @@ class Generate:
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
pl_sd = torch.load(ckpt, map_location='cpu')
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
c = OmegaConf.load(config)
|
||||
pl_sd = torch.load(weights, map_location='cpu')
|
||||
sd = pl_sd['state_dict']
|
||||
model = instantiate_from_config(c.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if self.full_precision:
|
||||
print(
|
||||
@ -573,7 +571,7 @@ class Generate:
|
||||
print(
|
||||
f'>> Model loaded in', '%4.2fs' % (toc - tic)
|
||||
)
|
||||
if device_type == 'cuda':
|
||||
if self._has_cuda():
|
||||
print(
|
||||
'>> Max VRAM used to load the model:',
|
||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
@ -710,3 +708,5 @@ class Generate:
|
||||
return width, height, resize_needed
|
||||
|
||||
|
||||
def _has_cuda(self):
|
||||
return self.device.type == 'cuda'
|
||||
|
@ -225,7 +225,7 @@ class DDIMSampler(object):
|
||||
total_steps = (
|
||||
timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
)
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
print(f'\nRunning DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
|
@ -33,53 +33,35 @@ def main():
|
||||
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
|
||||
sys.exit(-1)
|
||||
|
||||
try:
|
||||
models = OmegaConf.load(opt.config)
|
||||
width = models[opt.model].width
|
||||
height = models[opt.model].height
|
||||
config = models[opt.model].config
|
||||
weights = models[opt.model].weights
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
print('* Initializing, be patient...\n')
|
||||
sys.path.append('.')
|
||||
from pytorch_lightning import logging
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# creating a simple text2image object with a handful of
|
||||
# creating a simple Generate object with a handful of
|
||||
# defaults passed on the command line.
|
||||
# additional parameters will be added (or overriden) during
|
||||
# the user input loop
|
||||
t2i = Generate(
|
||||
width=width,
|
||||
height=height,
|
||||
sampler_name=opt.sampler_name,
|
||||
weights=weights,
|
||||
full_precision=opt.full_precision,
|
||||
config=config,
|
||||
grid=opt.grid,
|
||||
# this is solely for recreating the prompt
|
||||
seamless=opt.seamless,
|
||||
embedding_path=opt.embedding_path,
|
||||
device_type=opt.device,
|
||||
ignore_ctrl_c=opt.infile is None,
|
||||
)
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.config,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
# make sure the output directory exists
|
||||
if not os.path.exists(opt.outdir):
|
||||
os.makedirs(opt.outdir)
|
||||
|
||||
# gets rid of annoying messages about random seed
|
||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||
|
||||
# load the infile as a list of lines
|
||||
infile = None
|
||||
if opt.infile:
|
||||
@ -98,21 +80,23 @@ def main():
|
||||
print(">> changed to seamless tiling mode")
|
||||
|
||||
# preload the model
|
||||
t2i.load_model()
|
||||
gen.load_model()
|
||||
|
||||
if not infile:
|
||||
print(
|
||||
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
|
||||
)
|
||||
|
||||
cmd_parser = create_cmd_parser()
|
||||
# web server loops forever
|
||||
if opt.web:
|
||||
dream_server_loop(t2i, opt.host, opt.port, opt.outdir)
|
||||
else:
|
||||
main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
|
||||
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
|
||||
sys.exit(0)
|
||||
|
||||
cmd_parser = create_cmd_parser()
|
||||
main_loop(gen, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
|
||||
|
||||
def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||
def main_loop(gen, outdir, prompt_as_dir, parser, infile):
|
||||
"""prompt/read/execute loop"""
|
||||
done = False
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
@ -132,9 +116,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
except EOFError:
|
||||
done = True
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
done = True
|
||||
continue
|
||||
|
||||
# skip empty lines
|
||||
if not command.strip():
|
||||
@ -184,6 +165,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
if len(opt.prompt) == 0:
|
||||
print('Try again with a prompt!')
|
||||
continue
|
||||
|
||||
# retrieve previous value!
|
||||
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
|
||||
try:
|
||||
@ -204,8 +186,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
opt.seed = None
|
||||
continue
|
||||
|
||||
do_grid = opt.grid or t2i.grid
|
||||
|
||||
if opt.with_variations is not None:
|
||||
# shotgun parsing, woo
|
||||
parts = []
|
||||
@ -258,11 +238,11 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
file_writer = PngWriter(current_outdir)
|
||||
prefix = file_writer.unique_prefix()
|
||||
results = [] # list of filename, prompt pairs
|
||||
grid_images = dict() # seed -> Image, only used if `do_grid`
|
||||
grid_images = dict() # seed -> Image, only used if `opt.grid`
|
||||
|
||||
def image_writer(image, seed, upscaled=False):
|
||||
path = None
|
||||
if do_grid:
|
||||
if opt.grid:
|
||||
grid_images[seed] = image
|
||||
else:
|
||||
if upscaled and opt.save_original:
|
||||
@ -278,16 +258,16 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
iter_opt.with_variations = opt.with_variations + this_variation
|
||||
iter_opt.variation_amount = 0
|
||||
normalized_prompt = PromptFormatter(
|
||||
t2i, iter_opt).normalize_prompt()
|
||||
gen, iter_opt).normalize_prompt()
|
||||
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
|
||||
elif opt.with_variations is not None:
|
||||
normalized_prompt = PromptFormatter(
|
||||
t2i, opt).normalize_prompt()
|
||||
gen, opt).normalize_prompt()
|
||||
# use the original seed - the per-iteration value is the last variation-seed
|
||||
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
|
||||
else:
|
||||
normalized_prompt = PromptFormatter(
|
||||
t2i, opt).normalize_prompt()
|
||||
gen, opt).normalize_prompt()
|
||||
metadata_prompt = f'{normalized_prompt} -S{seed}'
|
||||
path = file_writer.save_image_and_prompt_to_png(
|
||||
image, metadata_prompt, filename)
|
||||
@ -296,16 +276,21 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
|
||||
results.append([path, metadata_prompt])
|
||||
last_results.append([path, seed])
|
||||
|
||||
t2i.prompt2image(image_callback=image_writer, **vars(opt))
|
||||
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
|
||||
gen.prompt2image(
|
||||
image_callback=image_writer,
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt)
|
||||
)
|
||||
|
||||
if do_grid and len(grid_images) > 0:
|
||||
if opt.grid and len(grid_images) > 0:
|
||||
grid_img = make_grid(list(grid_images.values()))
|
||||
grid_seeds = list(grid_images.keys())
|
||||
first_seed = last_results[0][1]
|
||||
filename = f'{prefix}.{first_seed}.png'
|
||||
# TODO better metadata for grid images
|
||||
normalized_prompt = PromptFormatter(
|
||||
t2i, opt).normalize_prompt()
|
||||
gen, opt).normalize_prompt()
|
||||
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
|
||||
path = file_writer.save_image_and_prompt_to_png(
|
||||
grid_img, metadata_prompt, filename
|
||||
@ -337,11 +322,12 @@ def get_next_command(infile=None) -> str: # command string
|
||||
raise EOFError
|
||||
else:
|
||||
command = command.strip()
|
||||
print(f'#{command}')
|
||||
if len(command)>0:
|
||||
print(f'#{command}')
|
||||
return command
|
||||
|
||||
|
||||
def dream_server_loop(t2i, host, port, outdir):
|
||||
def dream_server_loop(gen, host, port, outdir):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
@ -349,7 +335,7 @@ def dream_server_loop(t2i, host, port, outdir):
|
||||
)
|
||||
|
||||
# Start server
|
||||
DreamServer.model = t2i
|
||||
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
|
||||
DreamServer.outdir = outdir
|
||||
dream_server = ThreadingDreamServer((host, port))
|
||||
print(">> Started Stable Diffusion dream server!")
|
||||
@ -519,13 +505,6 @@ def create_argv_parser():
|
||||
default='model',
|
||||
help='Indicates the Stable Diffusion model to use.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
'-d',
|
||||
type=str,
|
||||
default='cuda',
|
||||
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
default='stable-diffusion-1.4',
|
||||
|
Loading…
Reference in New Issue
Block a user