Merge branch 'development' into fix-prompts

commit 194c8e1c2e
Damian at mba, 2022-10-24 11:28:37 +02:00
12 changed files with 391 additions and 169 deletions

View File

@@ -1,20 +1,22 @@
# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
-laion400m:
-    config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
-    weights: models/ldm/text2img-large/model.ckpt
-    description: Latent Diffusion LAION400M model
-    width: 256
-    height: 256
stable-diffusion-1.4:
    config: configs/stable-diffusion/v1-inference.yaml
    weights: models/ldm/stable-diffusion-v1/model.ckpt
+#    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
    description: Stable Diffusion inference model version 1.4
+    default: true
    width: 512
    height: 512
+stable-diffusion-1.5:
+    config: configs/stable-diffusion/v1-inference.yaml
+    weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+#    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    description: Stable Diffusion inference model version 1.5
+    width: 512
+    height: 512
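The new `default: true` key marks the stanza to load when no model is named on the command line. For reference, a minimal sketch (not part of this commit) of how a stanza file like this can be inspected with OmegaConf, mirroring what the model cache's new `default_model()` method does later in this commit:

```
from omegaconf import OmegaConf

config = OmegaConf.load('configs/models.yaml')
for name in config:
    stanza = config[name]
    marker = '*' if stanza.get('default', False) else ' '
    print(f"{marker} {name}: {stanza.get('description', '')}")
```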

View File

@@ -8,7 +8,7 @@ hide:
## **Interactive Command Line Interface**

-The `invoke.py` script, located in `scripts/dream.py`, provides an interactive
+The `invoke.py` script, located in `scripts/`, provides an interactive
interface to image generation similar to the "invoke mothership" bot that Stable
AI provided on its Discord server.

View File

@@ -81,15 +81,18 @@ text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
It will generate three files:

- The image with the selected area highlighted.
+  - it will be named XXXXX.<imagename>.<prompt>.selected.png
- The image with the un-selected area highlighted.
+  - it will be named XXXXX.<imagename>.<prompt>.deselected.png
- The image with the selected area converted into a black and white
-  image according to the threshold level.
+  image according to the threshold level
+  - it will be named XXXXX.<imagename>.<prompt>.masked.png

-Note that none of these images are intended to be used as the mask
-passed to invoke via `-M` and may give unexpected results if you try
-to use them this way. Instead, use `!mask` for testing that you are
-selecting the right mask area, and then do inpainting using the
-best selection term and threshold.
+The `.masked.png` file can then be directly passed to the `invoke>`
+prompt in the CLI via the `-M` argument. Do not attempt this with
+the `selected.png` or `deselected.png` files, as they contain some
+transparency throughout the image and will not produce the desired
+results.

Here is an example of how `!mask` works:
@@ -120,7 +123,7 @@ It looks like we selected the hair pretty well at the 0.5 threshold
let's have some fun:

```
-invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair 0.5 -C20
+invoke> medusa with cobras -I ./test-pictures/curly.png -M 000019.curly.hair.masked.png -C20
>> loaded input image of size 512x512 from ./test-pictures/curly.png
...
Outputs:
@@ -129,6 +132,13 @@ Outputs:

<img src="../assets/inpainting/000024.801380492.png">

+You can also skip the `!mask` creation step and just select the masked
+region directly:
+```
+invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
+```

### Inpainting is not changing the masked region enough!

One of the things to understand about how inpainting works is that it
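Behind the `-tm` option, the text-to-mask pipeline segments the image with clipseg and thresholds the result (see the `_txt2mask()` changes later in this commit). A rough standalone sketch, assuming the `Txt2Mask` helper keeps the `segment()`/`to_mask()` interface shown there (import path and device argument are assumptions):

```
from PIL import Image
from ldm.invoke.txt2mask import Txt2Mask  # module path assumed

txt2mask = Txt2Mask(device='cpu')
segmented = txt2mask.segment(Image.open('./test-pictures/curly.png'), 'hair')
mask = segmented.to_mask(0.5)   # keep pixels above 0.5 confidence
mask.save('curly.hair.masked.png')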

View File

@@ -56,23 +56,8 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)

-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
+# this is fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.4'

"""Simplified text to image API for stable diffusion/latent diffusion
@@ -126,12 +111,13 @@ still work.
The full list of arguments to Generate() are:
gr = Generate(
          # these values are set once and shouldn't be changed
-         conf            = path to configuration file ('configs/models.yaml')
-         model           = symbolic name of the model in the configuration file
-         precision       = float precision to be used
+         conf:str        = path to configuration file ('configs/models.yaml')
+         model:str       = symbolic name of the model in the configuration file
+         precision:float = float precision to be used
+         safety_checker:bool = activate safety checker [False]
          # this value is sticky and maintained between generation calls
-         sampler_name    = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+         sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
          # these are deprecated - use conf and model instead
          weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
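For orientation, a minimal usage sketch of this API under the new model-resolution behavior (the prompt, step count, and output path are illustrative; `txt2img()` returning `(image, seed)` pairs follows this module's documented convention):

```
from ldm.generate import Generate  # import path per the module docstring

# model=None now means: use the 'default' stanza in models.yaml,
# falling back to FALLBACK_MODEL_NAME if none is marked default
gr = Generate(model=None, conf='configs/models.yaml', sampler_name='k_lms')
results = gr.txt2img('a fantastic alien landscape', steps=30, cfg_scale=7.5)
for image, seed in results:
    image.save(f'outputs/{seed}.png')  # hypothetical output location
```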
@@ -148,7 +134,7 @@ class Generate:
    def __init__(
            self,
-            model          = 'stable-diffusion-1.4',
+            model          = None,
            conf           = 'configs/models.yaml',
            embedding_path = None,
            sampler_name   = 'k_lms',
@@ -164,7 +150,6 @@ class Generate:
            free_gpu_mem=False,
    ):
        mconfig          = OmegaConf.load(conf)
-        self.model_name  = model
        self.height      = None
        self.width       = None
        self.model_cache = None
@@ -211,6 +196,7 @@ class Generate:
        # model caching system for fast switching
        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_name  = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME

        # for VRAM usage statistics
        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@@ -287,6 +273,8 @@ class Generate:
            upscale          = None,
            # this is specific to inpainting and causes more extreme inpainting
            inpaint_replace  = 0.0,
+            # This will help match inpainted areas to the original image more smoothly
+            mask_blur_radius: int = 8,
            # Set this True to handle KeyboardInterrupt internally
            catch_interrupts = False,
            hires_fix        = False,
@@ -407,7 +395,7 @@ class Generate:
                log_tokens =self.log_tokenization
            )

-            init_image,mask_image = self._make_images(
+            init_image, mask_image = self._make_images(
                init_img,
                init_mask,
                width,
@@ -454,6 +442,7 @@ class Generate:
                embiggen=embiggen,
                embiggen_tiles=embiggen_tiles,
                inpaint_replace=inpaint_replace,
+                mask_blur_radius=mask_blur_radius
            )

        if init_color:
@@ -572,16 +561,19 @@ class Generate:
            from ldm.invoke.restoration.outcrop import Outcrop
            extend_instructions = {}
            for direction,pixels in _pairwise(opt.outcrop):
-                extend_instructions[direction]=int(pixels)
-
-            restorer = Outcrop(image,self,)
-            return restorer.process (
-                extend_instructions,
-                opt            = opt,
-                orig_opt       = args,
-                image_callback = callback,
-                prefix         = prefix,
-            )
+                try:
+                    extend_instructions[direction]=int(pixels)
+                except ValueError:
+                    print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
+            if len(extend_instructions)>0:
+                restorer = Outcrop(image,self,)
+                return restorer.process (
+                    extend_instructions,
+                    opt            = opt,
+                    orig_opt       = args,
+                    image_callback = callback,
+                    prefix         = prefix,
+                )

        elif tool == 'embiggen':
            # fetch the metadata from the image
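The `_pairwise()` helper consumed above turns the flat `--outcrop` token list into direction/pixel pairs. A sketch of the idea (this implementation is hypothetical; the real helper lives elsewhere in the same module):

```
def _pairwise(iterable):
    # ['top','64','left','128'] -> ('top','64'), ('left','128')
    it = iter(iterable)
    return zip(it, it)

extend_instructions = {}
for direction, pixels in _pairwise(['top', '64', 'left', '128']):
    extend_instructions[direction] = int(pixels)
print(extend_instructions)  # {'top': 64, 'left': 128}
```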
@@ -645,23 +637,22 @@ class Generate:
        # if image has a transparent area and no mask was provided, then try to generate mask
        if self._has_transparency(image):
            self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
            init_mask = self._create_init_mask(image, width, height, fit=fit)

        if (image.width * image.height) > (self.width * self.height) and self.size_matters:
            print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
            self.size_matters = False

-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)

        if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
            init_mask = self._create_init_mask(mask_image,width,height,fit=fit)

        elif text_mask:
            init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)

-        return init_image, init_mask
+        return init_image,init_mask

    def _make_base(self):
        if not self.generators.get('base'):
@@ -717,8 +708,7 @@ class Generate:
        model_data = self.model_cache.get_model(model_name)
        if model_data is None or len(model_data) == 0:
-            print(f'** Model switch failed **')
-            return self.model
+            return None

        self.model = model_data['model']
        self.width = model_data['width']
@@ -879,46 +869,31 @@ class Generate:
    def _create_init_image(self, image, width, height, fit=True):
        image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

    def _create_init_mask(self, image, width, height, fit=True):
        # convert into a black/white mask
        image = self._image_to_mask(image)
        image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

    # The mask is expected to have the region to be inpainted
    # with alpha transparency. It converts it into a black/white
    # image with the transparent part black.
-    def _image_to_mask(self, mask_image, invert=False) -> Image:
-        # Obtain the mask from the transparency channel
-        mask = Image.new(mode="L", size=mask_image.size, color=255)
-        mask.putdata(mask_image.getdata(band=3))
+    def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
+        if mask_image.mode == 'L':
+            mask = mask_image
+        else:
+            # Obtain the mask from the transparency channel
+            mask = Image.new(mode="L", size=mask_image.size, color=255)
+            mask.putdata(mask_image.getdata(band=3))
        if invert:
            mask = ImageOps.invert(mask)
        return mask

+    # TODO: The latter part of this method repeats code from _create_init_mask()
    def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
        prompt = text_mask[0]
        confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -928,18 +903,8 @@ class Generate:
        segmented = self.txt2mask.segment(image, prompt)
        mask = segmented.to_mask(float(confidence_level))
        mask = mask.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask

    def _has_transparency(self, image):
        if image.info.get("transparency", None) is not None:

View File

@@ -113,8 +113,8 @@ PRECISION_CHOICES = [
]

# is there a way to pick this up during git commits?
-APP_ID      = 'lstein/stable-diffusion'
-APP_VERSION = 'v1.15'
+APP_ID      = 'invoke-ai/InvokeAI'
+APP_VERSION = 'v2.02'

class ArgFormatter(argparse.RawTextHelpFormatter):
    # use defined argument order to display usage
@@ -172,6 +172,7 @@ class Args(object):
        command = cmd_string.replace("'", "\\'")
        try:
            elements = shlex.split(command)
+            elements = [x.replace("\\'","'") for x in elements]
        except ValueError:
            import sys, traceback
            print(traceback.format_exc(), file=sys.stderr)
@@ -366,17 +367,16 @@ class Args(object):
        deprecated_group.add_argument('--laion400m')
        deprecated_group.add_argument('--weights') # deprecated
        model_group.add_argument(
-            '--conf',
+            '--config',
            '-c',
-            '-conf',
+            '-config',
            dest='conf',
            default='./configs/models.yaml',
            help='Path to configuration file for alternate models.',
        )
        model_group.add_argument(
            '--model',
-            default='stable-diffusion-1.4',
-            help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
+            help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
        )
        model_group.add_argument(
            '--png_compression','-z',
@@ -529,7 +529,7 @@ class Args(object):
            formatter_class=ArgFormatter,
            description=
            """
-            *Image generation:*
+            *Image generation*
            invoke> a fantastic alien landscape -W576 -H512 -s60 -n4

            *postprocessing*
@@ -544,6 +544,13 @@ class Args(object):
            !history lists all the commands issued during the current session.
            !NN retrieves the NNth command from the history

+            *Model manipulation*
+            !models -- list models in configs/models.yaml
+            !switch <model_name> -- switch to model named <model_name>
+            !import_model path/to/weights/file.ckpt -- adds a model to your config
+            !edit_model <model_name> -- edit a model's description
+            !del_model <model_name> -- delete a model
            """
        )
        render_group = parser.add_argument_group('General rendering')
@@ -840,7 +847,7 @@ def metadata_dumps(opt,
    # remove any image keys not mentioned in RFC #266
    rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
                         'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
-                         'init_img','init_mask']
+                         'init_img','init_mask','facetool','facetool_strength','upscale']

    rfc_dict ={}
@@ -924,7 +931,7 @@ def metadata_loads(metadata) -> list:
            for image in images:
                # repack the prompt and variations
                if 'prompt' in image:
-                    image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
+                    image['prompt'] = repack_prompt(image['prompt'])
                if 'variations' in image:
                    image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
                # fix a bit of semantic drift here
@@ -932,12 +939,19 @@ def metadata_loads(metadata) -> list:
                opt = Args()
                opt._cmd_switches = Namespace(**image)
                results.append(opt)
-    except KeyError as e:
+    except Exception as e:
        import sys, traceback
-        print('>> badly-formatted metadata',file=sys.stderr)
+        print('>> could not read metadata',file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    return results

+def repack_prompt(prompt_list:list)->str:
+    # in the common case of no weighting syntax, just return the prompt as is
+    if len(prompt_list) > 1:
+        return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
+    else:
+        return prompt_list[0]['prompt']
+
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):
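A quick worked example of the new `repack_prompt()` helper above, to make the round-trip concrete (the sample prompt lists are illustrative):

```
weighted = [{'prompt': 'a cat', 'weight': 1.0},
            {'prompt': 'blurry', 'weight': -1.0}]
print(repack_prompt(weighted))   # a cat:1.0,blurry:-1.0

single = [{'prompt': 'a cat', 'weight': 1.0}]
print(repack_prompt(single))     # a cat  (weight omitted, per the comment)
```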
@@ -967,17 +981,17 @@ def sha256(path):
    return sha.hexdigest()

def legacy_metadata_load(meta,pathname) -> Args:
+    opt = Args()
    if 'Dream' in meta and len(meta['Dream']) > 0:
        dream_prompt = meta['Dream']
-        opt = Args()
        opt.parse_cmd(dream_prompt)
-        return opt
    else: # if nothing else, we can get the seed
        match = re.search('\d+\.(\d+)',pathname)
        if match:
            seed = match.groups()[0]
-            opt = Args()
            opt.seed = seed
-            return opt
-    return None
+        else:
+            opt.prompt = ''
+            opt.seed = 0
+    return opt
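The net effect is that legacy loads now always return a populated Args object instead of sometimes returning None. A small sketch of the Args round-trip this relies on (the sample Dream string is illustrative, and attribute access after parse_cmd() is assumed from this file's usage):

```
opt = Args()
opt.parse_cmd('"a fantastic alien landscape" -s60 -W576 -H512 -S42')
print(opt.seed)   # 42
```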

View File

@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
import torch
import numpy as np
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
+import PIL
+from torch import Tensor
+from PIL import Image
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

class Img2Img(Generator):
@@ -26,6 +29,9 @@ class Img2Img(Generator):
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
@@ -71,3 +77,11 @@ class Img2Img(Generator):
        shape = init_latent.shape
        x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
        return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)
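To see what the new `_image_to_tensor()` produces, here is the transformation traced for a 512x512 RGB image (a standalone sketch, not part of the diff):

```
import numpy as np
import torch
from PIL import Image

image = Image.new('RGB', (512, 512))
arr = np.array(image).astype(np.float32) / 255.0   # (512, 512, 3), values in [0,1]
arr = arr[None].transpose(0, 3, 1, 2)              # (1, 3, 512, 512), HWC -> NCHW
tensor = 2.0 * torch.from_numpy(arr) - 1.0         # normalize=True maps to [-1,1]
print(tensor.shape)                                # torch.Size([1, 3, 512, 512])
```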

View File

@@ -3,27 +3,55 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''

import torch
+import torchvision.transforms as T
import numpy as np
+import cv2 as cv
+import PIL
+from PIL import Image, ImageFilter
+from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling

class Inpaint(Img2Img):
    def __init__(self, model, precision):
        self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,mask_image,strength,
-                       step_callback=None,inpaint_replace=False,**kwargs):
+                       mask_blur_radius: int = 8,
+                       step_callback=None,inpaint_replace=False, **kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask.  Return value depends on the seed at
        the time you call it.  kwargs are 'init_latent' and 'strength'
        """

+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius
+
        # klms samplers not supported yet, so ignore previous sampler
        if isinstance(sampler,KSampler):
            print(
@@ -78,10 +106,50 @@ class Inpaint(Img2Img):
                mask        = mask_image,
                init_latent = self.init_latent
            )

            return self.sample_to_image(samples)

        return make_image
+
+    def sample_to_image(self, samples)->Image.Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        if self.pil_image is None or self.pil_mask is None:
+            return gen_result
+
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask if there is one.
+        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
+        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')
+
+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result
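The color-correction step above leans on scikit-image's histogram matching. A standalone sketch of the idea, with synthetic arrays standing in for the generated result and the visible source pixels:

```
import numpy as np
from skimage.exposure import match_histograms  # public alias of the function used above

rng = np.random.default_rng(0)
gen = rng.integers(0, 256, (512, 512, 3), dtype=np.uint8)  # stand-in generated image
ref = rng.integers(0, 256, (1, 1000, 3), dtype=np.uint8)   # stand-in visible source pixels
matched = match_histograms(gen, ref, channel_axis=-1)      # per-channel histogram transfer
print(matched.shape)                                       # (512, 512, 3)
```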

View File

@@ -13,6 +13,7 @@ import gc
import hashlib
import psutil
import transformers
+import os
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
@@ -73,7 +74,8 @@ class ModelCache(object):
            except Exception as e:
                print(f'** model {model_name} could not be loaded: {str(e)}')
                print(f'** restoring {self.current_model}')
-                return self.get_model(self.current_model)
+                self.get_model(self.current_model)
+                return None

        self.current_model = model_name
        self._push_newest_model(model_name)
@@ -84,6 +86,26 @@ class ModelCache(object):
            'hash': hash
        }

+    def default_model(self) -> str:
+        '''
+        Returns the name of the default model, or None
+        if none is defined.
+        '''
+        for model_name in self.config:
+            if self.config[model_name].get('default',False):
+                return model_name
+        return None
+
+    def set_default_model(self,model_name:str):
+        '''
+        Set the default model. The change will not take
+        effect until you call model_cache.commit()
+        '''
+        assert model_name in self.models,f"unknown model '{model_name}'"
+        for model in self.models:
+            self.models[model].pop('default',None)
+        self.models[model_name]['default'] = True
+
    def list_models(self) -> dict:
        '''
        Return a dict of models in the format:
@@ -121,12 +143,23 @@ class ModelCache(object):
        else:
            print(line)

-    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
+    def del_model(self, model_name:str) ->bool:
+        '''
+        Delete the named model.
+        '''
+        omega = self.config
+        del omega[model_name]
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        return True
+
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
        '''
        Update the named model with a dictionary of attributes. Will fail with an
        assertion error if the name already exists. Pass clobber=True to overwrite.
-        On a successful update, the config will be changed in memory and a YAML
-        string will be returned.
+        On a successful update, the config will be changed in memory and the
+        method will return True. Will fail with an assertion error if provided
+        attributes are incorrect or the model name is missing.
        '''
        omega = self.config
        # check that all the required fields are present
@@ -139,7 +172,9 @@ class ModelCache(object):
            config[field] = model_attributes[field]

        omega[model_name] = config
-        return OmegaConf.to_yaml(omega)
+        if clobber:
+            self._invalidate_cached_model(model_name)
+        return True

    def _check_memory(self):
        avail_memory = psutil.virtual_memory()[1]
@@ -159,6 +194,7 @@ class ModelCache(object):
        mconfig = self.config[model_name]
        config = mconfig.config
        weights = mconfig.weights
+        vae = mconfig.get('vae',None)
        width = mconfig.width
        height = mconfig.height
@@ -188,9 +224,17 @@ class ModelCache(object):
        else:
            print('   | Using more accurate float32 precision')

+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f'   | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
        model.to(self.device)
        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
        model.cond_stage_model.device = self.device

        model.eval()
        for m in model.modules():
@@ -219,6 +263,36 @@ class ModelCache(object):
        if self._has_cuda():
            torch.cuda.empty_cache()

+    def commit(self,config_file_path:str):
+        '''
+        Write current configuration out to the indicated file.
+        '''
+        yaml_str = OmegaConf.to_yaml(self.config)
+        tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
+        with open(tmpfile, 'w') as outfile:
+            outfile.write(self.preamble())
+            outfile.write(yaml_str)
+        os.rename(tmpfile,config_file_path)
+
+    def preamble(self):
+        '''
+        Returns the preamble for the config file.
+        '''
+        return '''# This file describes the alternative machine learning models
+# available to the dream script.
+#
+# To add a new model, follow the examples below. Each
+# model requires a model config file, a weights file,
+# and the width and height of the images it
+# was trained on.
+'''
+
+    def _invalidate_cached_model(self,model_name:str):
+        self.unload_model(model_name)
+        if model_name in self.stack:
+            self.stack.remove(model_name)
+        self.models.pop(model_name,None)
+
    def _model_to_cpu(self,model):
        if self.device != 'cpu':
            model.cond_stage_model.device = 'cpu'

View File

@@ -38,7 +38,7 @@ class PngWriter:
        info = PngImagePlugin.PngInfo()
        info.add_text('Dream', dream_prompt)
-        info.add_text('sd-metadata', json.dumps(metadata))
+        if metadata:
+            info.add_text('sd-metadata', json.dumps(metadata))
        image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
        return path

View File

@@ -57,12 +57,13 @@ COMMANDS = (
    '--png_compression','-z',
    '--text_mask','-tm',
    '!fix','!fetch','!replay','!history','!search','!clear',
+    '!models','!switch','!import_model','!edit_model','!del_model',
    '!mask',
-    '!models','!switch','!import_model','!edit_model'
)
MODEL_COMMANDS = (
    '!switch',
    '!edit_model',
+    '!del_model',
)
WEIGHT_COMMANDS = (
    '!import_model',
@@ -218,9 +219,24 @@ class Completer(object):
        pydoc.pager('\n'.join(lines))

    def set_line(self,line)->None:
+        '''
+        Set the default string displayed in the next line of input.
+        '''
        self.linebuffer = line
        readline.redisplay()

+    def add_model(self,model_name:str)->None:
+        '''
+        add a model name to the completion list
+        '''
+        self.models.append(model_name)
+
+    def del_model(self,model_name:str)->None:
+        '''
+        removes a model name from the completion list
+        '''
+        self.models.remove(model_name)
+
    def _seed_completions(self, text, state):
        m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
        if m:

View File

@@ -35,4 +35,4 @@ realesrgan
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg

View File

@@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
        completer.add_history(command)
        operation = None

+    elif command.startswith('!del'):
+        path = shlex.split(command)
+        if len(path) < 2:
+            print('** please provide the name of a model')
+        else:
+            del_config(path[1], gen, opt, completer)
+        completer.add_history(command)
+        operation = None
+
    elif command.startswith('!fetch'):
        file_path = command.replace('!fetch','',1).strip()
        retrieve_dream_command(opt,file_path,completer)
@@ -484,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
        new_config['config'] = input('Configuration file for this model: ')
        done = os.path.exists(new_config['config'])

+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
    completer.complete_extensions(None)

    for field in ('width','height'):
@@ -498,9 +517,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
        except:
            print('** Please enter a valid integer between 64 and 2048')

-    if write_config_file(opt.conf, gen, model_name, new_config):
-        gen.set_model(model_name)
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
+
+    if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
+        completer.add_model(model_name)
+
+def del_config(model_name:str, gen, opt, completer):
+    current_model = gen.model_name
+    if model_name == current_model:
+        print("** Can't delete active model. !switch to another model first. **")
+        return
+    yaml_str = gen.model_cache.del_model(model_name)
+
+    tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
+    with open(tmpfile, 'w') as outfile:
+        outfile.write(yaml_str)
+    os.rename(tmpfile,opt.conf)
+    print(f'** {model_name} deleted')
+    completer.del_model(model_name)

def edit_config(model_name:str, gen, opt, completer):
    config = gen.model_cache.config
@@ -512,33 +547,46 @@ def edit_config(model_name:str, gen, opt, completer):
    conf = config[model_name]
    new_config = {}
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
        completer.linebuffer = str(conf[field]) if field in conf else ''
        new_value = input(f'{field}: ')
        new_config[field] = int(new_value) if field in ('width','height') else new_value
+    make_default = input('Make this the default model? [n] ') in ('y','Y')
    completer.complete_extensions(None)
+    write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)

-    if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
-        gen.set_model(model_name)
-
-def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
+def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
+    current_model = gen.model_name
    op = 'modify' if clobber else 'import'
    print('\n>> New configuration:')
+    if make_default:
+        new_config['default'] = True
    print(yaml.dump({model_name:new_config}))
    if input(f'OK to {op} [n]? ') not in ('y','Y'):
        return False

    try:
+        print('>> Verifying that new model loads...')
        yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
+        assert gen.set_model(model_name) is not None, 'model failed to load'
    except AssertionError as e:
-        print(f'** configuration failed: {str(e)}')
+        print(f'** aborting **')
+        gen.model_cache.del_model(model_name)
        return False

-    tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
-    with open(tmpfile, 'w') as outfile:
-        outfile.write(yaml_str)
-    os.rename(tmpfile,conf_path)
+    if make_default:
+        print('making this default')
+        gen.model_cache.set_default_model(model_name)
+
+    gen.model_cache.commit(conf_path)
+
+    do_switch = input(f'Keep model loaded? [y]')
+    if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+        pass
+    else:
+        gen.set_model(current_model)
    return True

def do_textmask(gen, opt, callback):
@@ -598,7 +646,10 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
    original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
    new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
    meta = retrieve_metadata(original_file)['sd-metadata']
-    img_data = meta['image']
+    if 'image' not in meta:
+        meta = metadata_dumps(opt,seeds=[opt.seed])['image']
+        meta['image'] = {}
+    img_data = meta.get('image')
    pp = img_data.get('postprocessing',[]) or []
    pp.append(
        {
@@ -748,26 +799,38 @@ def retrieve_dream_command(opt,command,completer):
    will retrieve and format the dream command used to generate the image,
    and pop it into the readline buffer (linux, Mac), or print out a comment
    for cut-and-paste (windows)

    Given a wildcard path to a folder with image png files,
    will retrieve and format the dream command used to generate the images,
    and save them to a file commands.txt for further processing
    '''
    if len(command) == 0:
        return

    tokens = command.split()
-    if len(tokens) > 1:
-        outfilepath = tokens[1]
-    else:
-        outfilepath = "commands.txt"
-    file_path = tokens[0]
-    dir,basename = os.path.split(file_path)
-    if len(dir) == 0:
-        dir = opt.outdir
-    else:
-        outdir,outname = os.path.split(outfilepath)
-        if len(outdir) == 0:
-            outfilepath = os.path.join(dir,outname)
+    dir,basename = os.path.split(tokens[0])
+    if len(dir) == 0:
+        path = os.path.join(opt.outdir,basename)
+    else:
+        path = tokens[0]
+
+    if len(tokens) > 1:
+        return write_commands(opt, path, tokens[1])
+
+    cmd = ''
+    try:
+        cmd = dream_cmd_from_png(path)
+    except OSError:
+        print(f'## {tokens[0]}: file could not be read')
+    except (KeyError, AttributeError, IndexError):
+        print(f'## {tokens[0]}: file has no metadata')
+    except:
+        print(f'## {tokens[0]}: file could not be processed')
+    if len(cmd)>0:
+        completer.set_line(cmd)
+
+def write_commands(opt, file_path:str, outfilepath:str):
+    dir,basename = os.path.split(file_path)
    try:
        paths = list(Path(dir).glob(basename))
    except ValueError:
@@ -775,28 +838,24 @@ def retrieve_dream_command(opt,command,completer):
        return

    commands = []
+    cmd = None
    for path in paths:
        try:
            cmd = dream_cmd_from_png(path)
-        except OSError:
-            print(f'## {path}: file could not be read')
-            continue
        except (KeyError, AttributeError, IndexError):
            print(f'## {path}: file has no metadata')
-            continue
        except:
            print(f'## {path}: file could not be processed')
-            continue
-
-        commands.append(f'# {path}')
-        commands.append(cmd)
-
-    with open(outfilepath, 'w', encoding='utf-8') as f:
-        f.write('\n'.join(commands))
-    print(f'>> File {outfilepath} with commands created')
-
-    if len(commands) == 2:
-        completer.set_line(commands[1])
+        if cmd:
+            commands.append(f'# {path}')
+            commands.append(cmd)
+    if len(commands)>0:
+        dir,basename = os.path.split(outfilepath)
+        if len(dir)==0:
+            outfilepath = os.path.join(opt.outdir,basename)
+        with open(outfilepath, 'w', encoding='utf-8') as f:
+            f.write('\n'.join(commands))
+        print(f'>> File {outfilepath} with commands created')
###################################### ######################################