mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into fix-prompts
This commit is contained in:
commit
194c8e1c2e
@ -1,20 +1,22 @@
|
|||||||
# This file describes the alternative machine learning models
|
# This file describes the alternative machine learning models
|
||||||
# available to the dream script.
|
# available to the dream script.
|
||||||
#
|
#
|
||||||
# To add a new model, follow the examples below. Each
|
# To add a new model, follow the examples below. Each
|
||||||
# model requires a model config file, a weights file,
|
# model requires a model config file, a weights file,
|
||||||
# and the width and height of the images it
|
# and the width and height of the images it
|
||||||
# was trained on.
|
# was trained on.
|
||||||
|
|
||||||
laion400m:
|
|
||||||
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
|
|
||||||
weights: models/ldm/text2img-large/model.ckpt
|
|
||||||
description: Latent Diffusion LAION400M model
|
|
||||||
width: 256
|
|
||||||
height: 256
|
|
||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
description: Stable Diffusion inference model version 1.4
|
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
width: 512
|
description: Stable Diffusion inference model version 1.4
|
||||||
height: 512
|
default: true
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
stable-diffusion-1.5:
|
||||||
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
|
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
|
description: Stable Diffusion inference model version 1.5
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
@ -8,7 +8,7 @@ hide:
|
|||||||
|
|
||||||
## **Interactive Command Line Interface**
|
## **Interactive Command Line Interface**
|
||||||
|
|
||||||
The `invoke.py` script, located in `scripts/dream.py`, provides an interactive
|
The `invoke.py` script, located in `scripts/`, provides an interactive
|
||||||
interface to image generation similar to the "invoke mothership" bot that Stable
|
interface to image generation similar to the "invoke mothership" bot that Stable
|
||||||
AI provided on its Discord server.
|
AI provided on its Discord server.
|
||||||
|
|
||||||
|
@ -81,15 +81,18 @@ text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
|
|||||||
It will generate three files:
|
It will generate three files:
|
||||||
|
|
||||||
- The image with the selected area highlighted.
|
- The image with the selected area highlighted.
|
||||||
|
- it will be named XXXXX.<imagename>.<prompt>.selected.png
|
||||||
- The image with the un-selected area highlighted.
|
- The image with the un-selected area highlighted.
|
||||||
|
- it will be named XXXXX.<imagename>.<prompt>.deselected.png
|
||||||
- The image with the selected area converted into a black and white
|
- The image with the selected area converted into a black and white
|
||||||
image according to the threshold level.
|
image according to the threshold level
|
||||||
|
- it will be named XXXXX.<imagename>.<prompt>.masked.png
|
||||||
|
|
||||||
Note that none of these images are intended to be used as the mask
|
The `.masked.png` file can then be directly passed to the `invoke>`
|
||||||
passed to invoke via `-M` and may give unexpected results if you try
|
prompt in the CLI via the `-M` argument. Do not attempt this with
|
||||||
to use them this way. Instead, use `!mask` for testing that you are
|
the `selected.png` or `deselected.png` files, as they contain some
|
||||||
selecting the right mask area, and then do inpainting using the
|
transparency throughout the image and will not produce the desired
|
||||||
best selection term and threshold.
|
results.
|
||||||
|
|
||||||
Here is an example of how `!mask` works:
|
Here is an example of how `!mask` works:
|
||||||
|
|
||||||
@ -120,7 +123,7 @@ It looks like we selected the hair pretty well at the 0.5 threshold
|
|||||||
let's have some fun:
|
let's have some fun:
|
||||||
|
|
||||||
```
|
```
|
||||||
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair 0.5 -C20
|
invoke> medusa with cobras -I ./test-pictures/curly.png -M 000019.curly.hair.masked.png -C20
|
||||||
>> loaded input image of size 512x512 from ./test-pictures/curly.png
|
>> loaded input image of size 512x512 from ./test-pictures/curly.png
|
||||||
...
|
...
|
||||||
Outputs:
|
Outputs:
|
||||||
@ -129,6 +132,13 @@ Outputs:
|
|||||||
|
|
||||||
<img src="../assets/inpainting/000024.801380492.png">
|
<img src="../assets/inpainting/000024.801380492.png">
|
||||||
|
|
||||||
|
You can also skip the `!mask` creation step and just select the masked
|
||||||
|
|
||||||
|
region directly:
|
||||||
|
```
|
||||||
|
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
|
||||||
|
```
|
||||||
|
|
||||||
### Inpainting is not changing the masked region enough!
|
### Inpainting is not changing the masked region enough!
|
||||||
|
|
||||||
One of the things to understand about how inpainting works is that it
|
One of the things to understand about how inpainting works is that it
|
||||||
|
121
ldm/generate.py
121
ldm/generate.py
@ -56,23 +56,8 @@ torch.randint_like = fix_func(torch.randint_like)
|
|||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
torch.bernoulli = fix_func(torch.bernoulli)
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
torch.multinomial = fix_func(torch.multinomial)
|
||||||
|
|
||||||
def fix_func(orig):
|
# this is fallback model in case no default is defined
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||||
def new_func(*args, **kw):
|
|
||||||
device = kw.get("device", "mps")
|
|
||||||
kw["device"]="cpu"
|
|
||||||
return orig(*args, **kw).to(device)
|
|
||||||
return new_func
|
|
||||||
return orig
|
|
||||||
|
|
||||||
torch.rand = fix_func(torch.rand)
|
|
||||||
torch.rand_like = fix_func(torch.rand_like)
|
|
||||||
torch.randn = fix_func(torch.randn)
|
|
||||||
torch.randn_like = fix_func(torch.randn_like)
|
|
||||||
torch.randint = fix_func(torch.randint)
|
|
||||||
torch.randint_like = fix_func(torch.randint_like)
|
|
||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
@ -126,12 +111,13 @@ still work.
|
|||||||
The full list of arguments to Generate() are:
|
The full list of arguments to Generate() are:
|
||||||
gr = Generate(
|
gr = Generate(
|
||||||
# these values are set once and shouldn't be changed
|
# these values are set once and shouldn't be changed
|
||||||
conf = path to configuration file ('configs/models.yaml')
|
conf:str = path to configuration file ('configs/models.yaml')
|
||||||
model = symbolic name of the model in the configuration file
|
model:str = symbolic name of the model in the configuration file
|
||||||
precision = float precision to be used
|
precision:float = float precision to be used
|
||||||
|
safety_checker:bool = activate safety checker [False]
|
||||||
|
|
||||||
# this value is sticky and maintained between generation calls
|
# this value is sticky and maintained between generation calls
|
||||||
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||||
|
|
||||||
# these are deprecated - use conf and model instead
|
# these are deprecated - use conf and model instead
|
||||||
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
|
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
|
||||||
@ -148,7 +134,7 @@ class Generate:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model = 'stable-diffusion-1.4',
|
model = None,
|
||||||
conf = 'configs/models.yaml',
|
conf = 'configs/models.yaml',
|
||||||
embedding_path = None,
|
embedding_path = None,
|
||||||
sampler_name = 'k_lms',
|
sampler_name = 'k_lms',
|
||||||
@ -164,7 +150,6 @@ class Generate:
|
|||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
):
|
):
|
||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.model_name = model
|
|
||||||
self.height = None
|
self.height = None
|
||||||
self.width = None
|
self.width = None
|
||||||
self.model_cache = None
|
self.model_cache = None
|
||||||
@ -211,6 +196,7 @@ class Generate:
|
|||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
self.model_cache = ModelCache(mconfig,self.device,self.precision)
|
||||||
|
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
@ -287,6 +273,8 @@ class Generate:
|
|||||||
upscale = None,
|
upscale = None,
|
||||||
# this is specific to inpainting and causes more extreme inpainting
|
# this is specific to inpainting and causes more extreme inpainting
|
||||||
inpaint_replace = 0.0,
|
inpaint_replace = 0.0,
|
||||||
|
# This will help match inpainted areas to the original image more smoothly
|
||||||
|
mask_blur_radius: int = 8,
|
||||||
# Set this True to handle KeyboardInterrupt internally
|
# Set this True to handle KeyboardInterrupt internally
|
||||||
catch_interrupts = False,
|
catch_interrupts = False,
|
||||||
hires_fix = False,
|
hires_fix = False,
|
||||||
@ -407,7 +395,7 @@ class Generate:
|
|||||||
log_tokens =self.log_tokenization
|
log_tokens =self.log_tokenization
|
||||||
)
|
)
|
||||||
|
|
||||||
init_image,mask_image = self._make_images(
|
init_image, mask_image = self._make_images(
|
||||||
init_img,
|
init_img,
|
||||||
init_mask,
|
init_mask,
|
||||||
width,
|
width,
|
||||||
@ -454,6 +442,7 @@ class Generate:
|
|||||||
embiggen=embiggen,
|
embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
|
mask_blur_radius=mask_blur_radius
|
||||||
)
|
)
|
||||||
|
|
||||||
if init_color:
|
if init_color:
|
||||||
@ -572,16 +561,19 @@ class Generate:
|
|||||||
from ldm.invoke.restoration.outcrop import Outcrop
|
from ldm.invoke.restoration.outcrop import Outcrop
|
||||||
extend_instructions = {}
|
extend_instructions = {}
|
||||||
for direction,pixels in _pairwise(opt.outcrop):
|
for direction,pixels in _pairwise(opt.outcrop):
|
||||||
extend_instructions[direction]=int(pixels)
|
try:
|
||||||
|
extend_instructions[direction]=int(pixels)
|
||||||
restorer = Outcrop(image,self,)
|
except ValueError:
|
||||||
return restorer.process (
|
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
||||||
extend_instructions,
|
if len(extend_instructions)>0:
|
||||||
opt = opt,
|
restorer = Outcrop(image,self,)
|
||||||
orig_opt = args,
|
return restorer.process (
|
||||||
image_callback = callback,
|
extend_instructions,
|
||||||
prefix = prefix,
|
opt = opt,
|
||||||
)
|
orig_opt = args,
|
||||||
|
image_callback = callback,
|
||||||
|
prefix = prefix,
|
||||||
|
)
|
||||||
|
|
||||||
elif tool == 'embiggen':
|
elif tool == 'embiggen':
|
||||||
# fetch the metadata from the image
|
# fetch the metadata from the image
|
||||||
@ -645,23 +637,22 @@ class Generate:
|
|||||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||||
if self._has_transparency(image):
|
if self._has_transparency(image):
|
||||||
self._transparency_check_and_warning(image, mask)
|
self._transparency_check_and_warning(image, mask)
|
||||||
# this returns a torch tensor
|
|
||||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||||
|
|
||||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||||
self.size_matters = False
|
self.size_matters = False
|
||||||
|
|
||||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
init_image = self._create_init_image(image,width,height,fit=fit)
|
||||||
|
|
||||||
if mask:
|
if mask:
|
||||||
mask_image = self._load_img(mask) # this returns an Image
|
mask_image = self._load_img(mask)
|
||||||
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
||||||
|
|
||||||
elif text_mask:
|
elif text_mask:
|
||||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||||
|
|
||||||
return init_image, init_mask
|
return init_image,init_mask
|
||||||
|
|
||||||
def _make_base(self):
|
def _make_base(self):
|
||||||
if not self.generators.get('base'):
|
if not self.generators.get('base'):
|
||||||
@ -717,8 +708,7 @@ class Generate:
|
|||||||
|
|
||||||
model_data = self.model_cache.get_model(model_name)
|
model_data = self.model_cache.get_model(model_name)
|
||||||
if model_data is None or len(model_data) == 0:
|
if model_data is None or len(model_data) == 0:
|
||||||
print(f'** Model switch failed **')
|
return None
|
||||||
return self.model
|
|
||||||
|
|
||||||
self.model = model_data['model']
|
self.model = model_data['model']
|
||||||
self.width = model_data['width']
|
self.width = model_data['width']
|
||||||
@ -879,46 +869,31 @@ class Generate:
|
|||||||
|
|
||||||
def _create_init_image(self, image, width, height, fit=True):
|
def _create_init_image(self, image, width, height, fit=True):
|
||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
if fit:
|
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
|
||||||
image = self._fit_image(image, (width, height))
|
return image
|
||||||
else:
|
|
||||||
image = self._squeeze_image(image)
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
image = 2.0 * image - 1.0
|
|
||||||
return image.to(self.device)
|
|
||||||
|
|
||||||
def _create_init_mask(self, image, width, height, fit=True):
|
def _create_init_mask(self, image, width, height, fit=True):
|
||||||
# convert into a black/white mask
|
# convert into a black/white mask
|
||||||
image = self._image_to_mask(image)
|
image = self._image_to_mask(image)
|
||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
|
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
|
||||||
# now we adjust the size
|
return image
|
||||||
if fit:
|
|
||||||
image = self._fit_image(image, (width, height))
|
|
||||||
else:
|
|
||||||
image = self._squeeze_image(image)
|
|
||||||
image = image.resize((image.width//downsampling, image.height //
|
|
||||||
downsampling), resample=Image.Resampling.NEAREST)
|
|
||||||
image = np.array(image)
|
|
||||||
image = image.astype(np.float32) / 255.0
|
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
return image.to(self.device)
|
|
||||||
|
|
||||||
# The mask is expected to have the region to be inpainted
|
# The mask is expected to have the region to be inpainted
|
||||||
# with alpha transparency. It converts it into a black/white
|
# with alpha transparency. It converts it into a black/white
|
||||||
# image with the transparent part black.
|
# image with the transparent part black.
|
||||||
def _image_to_mask(self, mask_image, invert=False) -> Image:
|
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
|
||||||
# Obtain the mask from the transparency channel
|
# Obtain the mask from the transparency channel
|
||||||
mask = Image.new(mode="L", size=mask_image.size, color=255)
|
if mask_image.mode == 'L':
|
||||||
mask.putdata(mask_image.getdata(band=3))
|
mask = mask_image
|
||||||
|
else:
|
||||||
|
# Obtain the mask from the transparency channel
|
||||||
|
mask = Image.new(mode="L", size=mask_image.size, color=255)
|
||||||
|
mask.putdata(mask_image.getdata(band=3))
|
||||||
if invert:
|
if invert:
|
||||||
mask = ImageOps.invert(mask)
|
mask = ImageOps.invert(mask)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
# TODO: The latter part of this method repeats code from _create_init_mask()
|
|
||||||
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
||||||
prompt = text_mask[0]
|
prompt = text_mask[0]
|
||||||
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
||||||
@ -928,18 +903,8 @@ class Generate:
|
|||||||
segmented = self.txt2mask.segment(image, prompt)
|
segmented = self.txt2mask.segment(image, prompt)
|
||||||
mask = segmented.to_mask(float(confidence_level))
|
mask = segmented.to_mask(float(confidence_level))
|
||||||
mask = mask.convert('RGB')
|
mask = mask.convert('RGB')
|
||||||
# now we adjust the size
|
mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
|
||||||
if fit:
|
return mask
|
||||||
mask = self._fit_image(mask, (width, height))
|
|
||||||
else:
|
|
||||||
mask = self._squeeze_image(mask)
|
|
||||||
mask = mask.resize((mask.width//downsampling, mask.height //
|
|
||||||
downsampling), resample=Image.Resampling.NEAREST)
|
|
||||||
mask = np.array(mask)
|
|
||||||
mask = mask.astype(np.float32) / 255.0
|
|
||||||
mask = mask[None].transpose(0, 3, 1, 2)
|
|
||||||
mask = torch.from_numpy(mask)
|
|
||||||
return mask.to(self.device)
|
|
||||||
|
|
||||||
def _has_transparency(self, image):
|
def _has_transparency(self, image):
|
||||||
if image.info.get("transparency", None) is not None:
|
if image.info.get("transparency", None) is not None:
|
||||||
|
@ -113,8 +113,8 @@ PRECISION_CHOICES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# is there a way to pick this up during git commits?
|
# is there a way to pick this up during git commits?
|
||||||
APP_ID = 'lstein/stable-diffusion'
|
APP_ID = 'invoke-ai/InvokeAI'
|
||||||
APP_VERSION = 'v1.15'
|
APP_VERSION = 'v2.02'
|
||||||
|
|
||||||
class ArgFormatter(argparse.RawTextHelpFormatter):
|
class ArgFormatter(argparse.RawTextHelpFormatter):
|
||||||
# use defined argument order to display usage
|
# use defined argument order to display usage
|
||||||
@ -172,6 +172,7 @@ class Args(object):
|
|||||||
command = cmd_string.replace("'", "\\'")
|
command = cmd_string.replace("'", "\\'")
|
||||||
try:
|
try:
|
||||||
elements = shlex.split(command)
|
elements = shlex.split(command)
|
||||||
|
elements = [x.replace("\\'","'") for x in elements]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
import sys, traceback
|
import sys, traceback
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
@ -366,17 +367,16 @@ class Args(object):
|
|||||||
deprecated_group.add_argument('--laion400m')
|
deprecated_group.add_argument('--laion400m')
|
||||||
deprecated_group.add_argument('--weights') # deprecated
|
deprecated_group.add_argument('--weights') # deprecated
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--conf',
|
'--config',
|
||||||
'-c',
|
'-c',
|
||||||
'-conf',
|
'-config',
|
||||||
dest='conf',
|
dest='conf',
|
||||||
default='./configs/models.yaml',
|
default='./configs/models.yaml',
|
||||||
help='Path to configuration file for alternate models.',
|
help='Path to configuration file for alternate models.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--model',
|
'--model',
|
||||||
default='stable-diffusion-1.4',
|
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||||
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
|
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
@ -529,7 +529,7 @@ class Args(object):
|
|||||||
formatter_class=ArgFormatter,
|
formatter_class=ArgFormatter,
|
||||||
description=
|
description=
|
||||||
"""
|
"""
|
||||||
*Image generation:*
|
*Image generation*
|
||||||
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
|
||||||
|
|
||||||
*postprocessing*
|
*postprocessing*
|
||||||
@ -544,6 +544,13 @@ class Args(object):
|
|||||||
!history lists all the commands issued during the current session.
|
!history lists all the commands issued during the current session.
|
||||||
|
|
||||||
!NN retrieves the NNth command from the history
|
!NN retrieves the NNth command from the history
|
||||||
|
|
||||||
|
*Model manipulation*
|
||||||
|
!models -- list models in configs/models.yaml
|
||||||
|
!switch <model_name> -- switch to model named <model_name>
|
||||||
|
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
||||||
|
!edit_model <model_name> -- edit a model's description
|
||||||
|
!del_model <model_name> -- delete a model
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
render_group = parser.add_argument_group('General rendering')
|
render_group = parser.add_argument_group('General rendering')
|
||||||
@ -840,7 +847,7 @@ def metadata_dumps(opt,
|
|||||||
# remove any image keys not mentioned in RFC #266
|
# remove any image keys not mentioned in RFC #266
|
||||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
||||||
'init_img','init_mask']
|
'init_img','init_mask','facetool','facetool_strength','upscale']
|
||||||
|
|
||||||
rfc_dict ={}
|
rfc_dict ={}
|
||||||
|
|
||||||
@ -924,7 +931,7 @@ def metadata_loads(metadata) -> list:
|
|||||||
for image in images:
|
for image in images:
|
||||||
# repack the prompt and variations
|
# repack the prompt and variations
|
||||||
if 'prompt' in image:
|
if 'prompt' in image:
|
||||||
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
|
image['prompt'] = repack_prompt(image['prompt'])
|
||||||
if 'variations' in image:
|
if 'variations' in image:
|
||||||
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
|
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
|
||||||
# fix a bit of semantic drift here
|
# fix a bit of semantic drift here
|
||||||
@ -932,12 +939,19 @@ def metadata_loads(metadata) -> list:
|
|||||||
opt = Args()
|
opt = Args()
|
||||||
opt._cmd_switches = Namespace(**image)
|
opt._cmd_switches = Namespace(**image)
|
||||||
results.append(opt)
|
results.append(opt)
|
||||||
except KeyError as e:
|
except Exception as e:
|
||||||
import sys, traceback
|
import sys, traceback
|
||||||
print('>> badly-formatted metadata',file=sys.stderr)
|
print('>> could not read metadata',file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def repack_prompt(prompt_list:list)->str:
|
||||||
|
# in the common case of no weighting syntax, just return the prompt as is
|
||||||
|
if len(prompt_list) > 1:
|
||||||
|
return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
|
||||||
|
else:
|
||||||
|
return prompt_list[0]['prompt']
|
||||||
|
|
||||||
# image can either be a file path on disk or a base64-encoded
|
# image can either be a file path on disk or a base64-encoded
|
||||||
# representation of the file's contents
|
# representation of the file's contents
|
||||||
def calculate_init_img_hash(image_string):
|
def calculate_init_img_hash(image_string):
|
||||||
@ -967,17 +981,17 @@ def sha256(path):
|
|||||||
return sha.hexdigest()
|
return sha.hexdigest()
|
||||||
|
|
||||||
def legacy_metadata_load(meta,pathname) -> Args:
|
def legacy_metadata_load(meta,pathname) -> Args:
|
||||||
|
opt = Args()
|
||||||
if 'Dream' in meta and len(meta['Dream']) > 0:
|
if 'Dream' in meta and len(meta['Dream']) > 0:
|
||||||
dream_prompt = meta['Dream']
|
dream_prompt = meta['Dream']
|
||||||
opt = Args()
|
|
||||||
opt.parse_cmd(dream_prompt)
|
opt.parse_cmd(dream_prompt)
|
||||||
return opt
|
|
||||||
else: # if nothing else, we can get the seed
|
else: # if nothing else, we can get the seed
|
||||||
match = re.search('\d+\.(\d+)',pathname)
|
match = re.search('\d+\.(\d+)',pathname)
|
||||||
if match:
|
if match:
|
||||||
seed = match.groups()[0]
|
seed = match.groups()[0]
|
||||||
opt = Args()
|
|
||||||
opt.seed = seed
|
opt.seed = seed
|
||||||
return opt
|
else:
|
||||||
return None
|
opt.prompt = ''
|
||||||
|
opt.seed = 0
|
||||||
|
return opt
|
||||||
|
|
||||||
|
@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ldm.invoke.devices import choose_autocast
|
import PIL
|
||||||
from ldm.invoke.generator.base import Generator
|
from torch import Tensor
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from PIL import Image
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.invoke.generator.base import Generator
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
|
||||||
class Img2Img(Generator):
|
class Img2Img(Generator):
|
||||||
@ -26,6 +29,9 @@ class Img2Img(Generator):
|
|||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
init_image = self._image_to_tensor(init_image)
|
||||||
|
|
||||||
scope = choose_autocast(self.precision)
|
scope = choose_autocast(self.precision)
|
||||||
with scope(self.model.device.type):
|
with scope(self.model.device.type):
|
||||||
self.init_latent = self.model.get_first_stage_encoding(
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
@ -71,3 +77,11 @@ class Img2Img(Generator):
|
|||||||
shape = init_latent.shape
|
shape = init_latent.shape
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
if normalize:
|
||||||
|
image = 2.0 * image - 1.0
|
||||||
|
return image.to(self.model.device)
|
||||||
|
@ -3,27 +3,55 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import cv2 as cv
|
||||||
|
import PIL
|
||||||
|
from PIL import Image, ImageFilter
|
||||||
|
from skimage.exposure.histogram_matching import match_histograms
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
|
from ldm.invoke.generator.base import downsampling
|
||||||
|
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.pil_image = None
|
||||||
|
self.pil_mask = None
|
||||||
|
self.mask_blur_radius = 0
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
conditioning,init_image,mask_image,strength,
|
conditioning,init_image,mask_image,strength,
|
||||||
step_callback=None,inpaint_replace=False,**kwargs):
|
mask_blur_radius: int = 8,
|
||||||
|
step_callback=None,inpaint_replace=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and
|
Returns a function returning an image derived from the prompt and
|
||||||
the initial image + mask. Return value depends on the seed at
|
the initial image + mask. Return value depends on the seed at
|
||||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
self.pil_image = init_image
|
||||||
|
init_image = self._image_to_tensor(init_image)
|
||||||
|
|
||||||
|
if isinstance(mask_image, PIL.Image.Image):
|
||||||
|
self.pil_mask = mask_image
|
||||||
|
mask_image = mask_image.resize(
|
||||||
|
(
|
||||||
|
mask_image.width // downsampling,
|
||||||
|
mask_image.height // downsampling
|
||||||
|
),
|
||||||
|
resample=Image.Resampling.NEAREST
|
||||||
|
)
|
||||||
|
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||||
|
|
||||||
|
self.mask_blur_radius = mask_blur_radius
|
||||||
|
|
||||||
# klms samplers not supported yet, so ignore previous sampler
|
# klms samplers not supported yet, so ignore previous sampler
|
||||||
if isinstance(sampler,KSampler):
|
if isinstance(sampler,KSampler):
|
||||||
print(
|
print(
|
||||||
@ -78,10 +106,50 @@ class Inpaint(Img2Img):
|
|||||||
mask = mask_image,
|
mask = mask_image,
|
||||||
init_latent = self.init_latent
|
init_latent = self.init_latent
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.sample_to_image(samples)
|
return self.sample_to_image(samples)
|
||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
|
def sample_to_image(self, samples)->Image.Image:
|
||||||
|
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||||
|
|
||||||
|
if self.pil_image is None or self.pil_mask is None:
|
||||||
|
return gen_result
|
||||||
|
|
||||||
|
pil_mask = self.pil_mask
|
||||||
|
pil_image = self.pil_image
|
||||||
|
mask_blur_radius = self.mask_blur_radius
|
||||||
|
|
||||||
|
# Get the original alpha channel of the mask if there is one.
|
||||||
|
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||||
|
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
|
||||||
|
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||||
|
|
||||||
|
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||||
|
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||||
|
# histogram and cause slight color changes.
|
||||||
|
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||||
|
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||||
|
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||||
|
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||||
|
|
||||||
|
# Get numpy version
|
||||||
|
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Color correct
|
||||||
|
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||||
|
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||||
|
|
||||||
|
# Blur the mask out (into init image) by specified amount
|
||||||
|
if mask_blur_radius > 0:
|
||||||
|
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||||
|
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||||
|
pmd = Image.fromarray(nmd, mode='L')
|
||||||
|
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||||
|
else:
|
||||||
|
blurred_init_mask = pil_init_mask
|
||||||
|
|
||||||
|
# Paste original on color-corrected generation (using blurred mask)
|
||||||
|
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||||
|
return matched_result
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ import gc
|
|||||||
import hashlib
|
import hashlib
|
||||||
import psutil
|
import psutil
|
||||||
import transformers
|
import transformers
|
||||||
|
import os
|
||||||
from sys import getrefcount
|
from sys import getrefcount
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.errors import ConfigAttributeError
|
from omegaconf.errors import ConfigAttributeError
|
||||||
@ -73,7 +74,8 @@ class ModelCache(object):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||||
print(f'** restoring {self.current_model}')
|
print(f'** restoring {self.current_model}')
|
||||||
return self.get_model(self.current_model)
|
self.get_model(self.current_model)
|
||||||
|
return None
|
||||||
|
|
||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
@ -84,6 +86,26 @@ class ModelCache(object):
|
|||||||
'hash': hash
|
'hash': hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def default_model(self) -> str:
|
||||||
|
'''
|
||||||
|
Returns the name of the default model, or None
|
||||||
|
if none is defined.
|
||||||
|
'''
|
||||||
|
for model_name in self.config:
|
||||||
|
if self.config[model_name].get('default',False):
|
||||||
|
return model_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_default_model(self,model_name:str):
|
||||||
|
'''
|
||||||
|
Set the default model. The change will not take
|
||||||
|
effect until you call model_cache.commit()
|
||||||
|
'''
|
||||||
|
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||||
|
for model in self.models:
|
||||||
|
self.models[model].pop('default',None)
|
||||||
|
self.models[model_name]['default'] = True
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
def list_models(self) -> dict:
|
||||||
'''
|
'''
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
@ -121,12 +143,23 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
|
def del_model(self, model_name:str) ->bool:
|
||||||
|
'''
|
||||||
|
Delete the named model.
|
||||||
|
'''
|
||||||
|
omega = self.config
|
||||||
|
del omega[model_name]
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->True:
|
||||||
'''
|
'''
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory and a YAML
|
On a successful update, the config will be changed in memory and the
|
||||||
string will be returned.
|
method will return True. Will fail with an assertion error if provided
|
||||||
|
attributes are incorrect or the model name is missing.
|
||||||
'''
|
'''
|
||||||
omega = self.config
|
omega = self.config
|
||||||
# check that all the required fields are present
|
# check that all the required fields are present
|
||||||
@ -139,7 +172,9 @@ class ModelCache(object):
|
|||||||
config[field] = model_attributes[field]
|
config[field] = model_attributes[field]
|
||||||
|
|
||||||
omega[model_name] = config
|
omega[model_name] = config
|
||||||
return OmegaConf.to_yaml(omega)
|
if clobber:
|
||||||
|
self._invalidate_cached_model(model_name)
|
||||||
|
return True
|
||||||
|
|
||||||
def _check_memory(self):
|
def _check_memory(self):
|
||||||
avail_memory = psutil.virtual_memory()[1]
|
avail_memory = psutil.virtual_memory()[1]
|
||||||
@ -159,6 +194,7 @@ class ModelCache(object):
|
|||||||
mconfig = self.config[model_name]
|
mconfig = self.config[model_name]
|
||||||
config = mconfig.config
|
config = mconfig.config
|
||||||
weights = mconfig.weights
|
weights = mconfig.weights
|
||||||
|
vae = mconfig.get('vae',None)
|
||||||
width = mconfig.width
|
width = mconfig.width
|
||||||
height = mconfig.height
|
height = mconfig.height
|
||||||
|
|
||||||
@ -188,9 +224,17 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
print(' | Using more accurate float32 precision')
|
print(' | Using more accurate float32 precision')
|
||||||
|
|
||||||
|
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||||
|
if vae and os.path.exists(vae):
|
||||||
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
|
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
model.cond_stage_model.device = self.device
|
model.cond_stage_model.device = self.device
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
@ -219,6 +263,36 @@ class ModelCache(object):
|
|||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def commit(self,config_file_path:str):
|
||||||
|
'''
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
'''
|
||||||
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
|
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(self.preamble())
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,config_file_path)
|
||||||
|
|
||||||
|
def preamble(self):
|
||||||
|
'''
|
||||||
|
Returns the preamble for the config file.
|
||||||
|
'''
|
||||||
|
return '''# This file describes the alternative machine learning models
|
||||||
|
# available to the dream script.
|
||||||
|
#
|
||||||
|
# To add a new model, follow the examples below. Each
|
||||||
|
# model requires a model config file, a weights file,
|
||||||
|
# and the width and height of the images it
|
||||||
|
# was trained on.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def _invalidate_cached_model(self,model_name:str):
|
||||||
|
self.unload_model(model_name)
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
self.models.pop(model_name,None)
|
||||||
|
|
||||||
def _model_to_cpu(self,model):
|
def _model_to_cpu(self,model):
|
||||||
if self.device != 'cpu':
|
if self.device != 'cpu':
|
||||||
model.cond_stage_model.device = 'cpu'
|
model.cond_stage_model.device = 'cpu'
|
||||||
|
@ -38,7 +38,7 @@ class PngWriter:
|
|||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
info.add_text('Dream', dream_prompt)
|
info.add_text('Dream', dream_prompt)
|
||||||
if metadata:
|
if metadata:
|
||||||
info.add_text('sd-metadata', json.dumps(metadata))
|
info.add_text('sd-metadata', json.dumps(metadata))
|
||||||
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
|
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
@ -57,12 +57,13 @@ COMMANDS = (
|
|||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
'--text_mask','-tm',
|
'--text_mask','-tm',
|
||||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||||
|
'!models','!switch','!import_model','!edit_model','!del_model',
|
||||||
'!mask',
|
'!mask',
|
||||||
'!models','!switch','!import_model','!edit_model'
|
|
||||||
)
|
)
|
||||||
MODEL_COMMANDS = (
|
MODEL_COMMANDS = (
|
||||||
'!switch',
|
'!switch',
|
||||||
'!edit_model',
|
'!edit_model',
|
||||||
|
'!del_model',
|
||||||
)
|
)
|
||||||
WEIGHT_COMMANDS = (
|
WEIGHT_COMMANDS = (
|
||||||
'!import_model',
|
'!import_model',
|
||||||
@ -218,9 +219,24 @@ class Completer(object):
|
|||||||
pydoc.pager('\n'.join(lines))
|
pydoc.pager('\n'.join(lines))
|
||||||
|
|
||||||
def set_line(self,line)->None:
|
def set_line(self,line)->None:
|
||||||
|
'''
|
||||||
|
Set the default string displayed in the next line of input.
|
||||||
|
'''
|
||||||
self.linebuffer = line
|
self.linebuffer = line
|
||||||
readline.redisplay()
|
readline.redisplay()
|
||||||
|
|
||||||
|
def add_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
add a model name to the completion list
|
||||||
|
'''
|
||||||
|
self.models.append(model_name)
|
||||||
|
|
||||||
|
def del_model(self,model_name:str)->None:
|
||||||
|
'''
|
||||||
|
removes a model name from the completion list
|
||||||
|
'''
|
||||||
|
self.models.remove(model_name)
|
||||||
|
|
||||||
def _seed_completions(self, text, state):
|
def _seed_completions(self, text, state):
|
||||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||||
if m:
|
if m:
|
||||||
|
@ -35,4 +35,4 @@ realesrgan
|
|||||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||||
git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||||
|
@ -424,6 +424,15 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
|||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
|
elif command.startswith('!del'):
|
||||||
|
path = shlex.split(command)
|
||||||
|
if len(path) < 2:
|
||||||
|
print('** please provide the name of a model')
|
||||||
|
else:
|
||||||
|
del_config(path[1], gen, opt, completer)
|
||||||
|
completer.add_history(command)
|
||||||
|
operation = None
|
||||||
|
|
||||||
elif command.startswith('!fetch'):
|
elif command.startswith('!fetch'):
|
||||||
file_path = command.replace('!fetch','',1).strip()
|
file_path = command.replace('!fetch','',1).strip()
|
||||||
retrieve_dream_command(opt,file_path,completer)
|
retrieve_dream_command(opt,file_path,completer)
|
||||||
@ -484,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
new_config['config'] = input('Configuration file for this model: ')
|
new_config['config'] = input('Configuration file for this model: ')
|
||||||
done = os.path.exists(new_config['config'])
|
done = os.path.exists(new_config['config'])
|
||||||
|
|
||||||
|
done = False
|
||||||
|
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
|
||||||
|
while not done:
|
||||||
|
vae = input('VAE autoencoder file for this model [None]: ')
|
||||||
|
if os.path.exists(vae):
|
||||||
|
new_config['vae'] = vae
|
||||||
|
done = True
|
||||||
|
else:
|
||||||
|
done = len(vae)==0
|
||||||
|
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
|
||||||
for field in ('width','height'):
|
for field in ('width','height'):
|
||||||
@ -498,9 +517,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
|||||||
except:
|
except:
|
||||||
print('** Please enter a valid integer between 64 and 2048')
|
print('** Please enter a valid integer between 64 and 2048')
|
||||||
|
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config):
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
gen.set_model(model_name)
|
|
||||||
|
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||||
|
completer.add_model(model_name)
|
||||||
|
|
||||||
|
def del_config(model_name:str, gen, opt, completer):
|
||||||
|
current_model = gen.model_name
|
||||||
|
if model_name == current_model:
|
||||||
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
|
return
|
||||||
|
yaml_str = gen.model_cache.del_model(model_name)
|
||||||
|
|
||||||
|
tmpfile = os.path.join(os.path.dirname(opt.conf),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w') as outfile:
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.rename(tmpfile,opt.conf)
|
||||||
|
print(f'** {model_name} deleted')
|
||||||
|
completer.del_model(model_name)
|
||||||
|
|
||||||
def edit_config(model_name:str, gen, opt, completer):
|
def edit_config(model_name:str, gen, opt, completer):
|
||||||
config = gen.model_cache.config
|
config = gen.model_cache.config
|
||||||
|
|
||||||
@ -512,33 +547,46 @@ def edit_config(model_name:str, gen, opt, completer):
|
|||||||
|
|
||||||
conf = config[model_name]
|
conf = config[model_name]
|
||||||
new_config = {}
|
new_config = {}
|
||||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
|
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
||||||
for field in ('description', 'weights', 'config', 'width','height'):
|
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
|
||||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
completer.linebuffer = str(conf[field]) if field in conf else ''
|
||||||
new_value = input(f'{field}: ')
|
new_value = input(f'{field}: ')
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
||||||
|
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
|
|
||||||
gen.set_model(model_name)
|
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||||
|
current_model = gen.model_name
|
||||||
|
|
||||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
|
|
||||||
op = 'modify' if clobber else 'import'
|
op = 'modify' if clobber else 'import'
|
||||||
print('\n>> New configuration:')
|
print('\n>> New configuration:')
|
||||||
|
if make_default:
|
||||||
|
new_config['default'] = True
|
||||||
print(yaml.dump({model_name:new_config}))
|
print(yaml.dump({model_name:new_config}))
|
||||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print('>> Verifying that new model loads...')
|
||||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||||
|
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(f'** configuration failed: {str(e)}')
|
print(f'** aborting **')
|
||||||
|
gen.model_cache.del_model(model_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if make_default:
|
||||||
|
print('making this default')
|
||||||
|
gen.model_cache.set_default_model(model_name)
|
||||||
|
|
||||||
|
gen.model_cache.commit(conf_path)
|
||||||
|
|
||||||
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
|
do_switch = input(f'Keep model loaded? [y]')
|
||||||
with open(tmpfile, 'w') as outfile:
|
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||||
outfile.write(yaml_str)
|
pass
|
||||||
os.rename(tmpfile,conf_path)
|
else:
|
||||||
|
gen.set_model(current_model)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def do_textmask(gen, opt, callback):
|
def do_textmask(gen, opt, callback):
|
||||||
@ -598,7 +646,10 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
|
|||||||
original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
|
original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
|
||||||
new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
|
new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
|
||||||
meta = retrieve_metadata(original_file)['sd-metadata']
|
meta = retrieve_metadata(original_file)['sd-metadata']
|
||||||
img_data = meta['image']
|
if 'image' not in meta:
|
||||||
|
meta = metadata_dumps(opt,seeds=[opt.seed])['image']
|
||||||
|
meta['image'] = {}
|
||||||
|
img_data = meta.get('image')
|
||||||
pp = img_data.get('postprocessing',[]) or []
|
pp = img_data.get('postprocessing',[]) or []
|
||||||
pp.append(
|
pp.append(
|
||||||
{
|
{
|
||||||
@ -748,26 +799,38 @@ def retrieve_dream_command(opt,command,completer):
|
|||||||
will retrieve and format the dream command used to generate the image,
|
will retrieve and format the dream command used to generate the image,
|
||||||
and pop it into the readline buffer (linux, Mac), or print out a comment
|
and pop it into the readline buffer (linux, Mac), or print out a comment
|
||||||
for cut-and-paste (windows)
|
for cut-and-paste (windows)
|
||||||
|
|
||||||
Given a wildcard path to a folder with image png files,
|
Given a wildcard path to a folder with image png files,
|
||||||
will retrieve and format the dream command used to generate the images,
|
will retrieve and format the dream command used to generate the images,
|
||||||
and save them to a file commands.txt for further processing
|
and save them to a file commands.txt for further processing
|
||||||
'''
|
'''
|
||||||
if len(command) == 0:
|
if len(command) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
tokens = command.split()
|
tokens = command.split()
|
||||||
if len(tokens) > 1:
|
dir,basename = os.path.split(tokens[0])
|
||||||
outfilepath = tokens[1]
|
|
||||||
else:
|
|
||||||
outfilepath = "commands.txt"
|
|
||||||
|
|
||||||
file_path = tokens[0]
|
|
||||||
dir,basename = os.path.split(file_path)
|
|
||||||
if len(dir) == 0:
|
if len(dir) == 0:
|
||||||
dir = opt.outdir
|
path = os.path.join(opt.outdir,basename)
|
||||||
|
else:
|
||||||
outdir,outname = os.path.split(outfilepath)
|
path = tokens[0]
|
||||||
if len(outdir) == 0:
|
|
||||||
outfilepath = os.path.join(dir,outname)
|
if len(tokens) > 1:
|
||||||
|
return write_commands(opt, path, tokens[1])
|
||||||
|
|
||||||
|
cmd = ''
|
||||||
|
try:
|
||||||
|
cmd = dream_cmd_from_png(path)
|
||||||
|
except OSError:
|
||||||
|
print(f'## {tokens[0]}: file could not be read')
|
||||||
|
except (KeyError, AttributeError, IndexError):
|
||||||
|
print(f'## {tokens[0]}: file has no metadata')
|
||||||
|
except:
|
||||||
|
print(f'## {tokens[0]}: file could not be processed')
|
||||||
|
if len(cmd)>0:
|
||||||
|
completer.set_line(cmd)
|
||||||
|
|
||||||
|
def write_commands(opt, file_path:str, outfilepath:str):
|
||||||
|
dir,basename = os.path.split(file_path)
|
||||||
try:
|
try:
|
||||||
paths = list(Path(dir).glob(basename))
|
paths = list(Path(dir).glob(basename))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -775,28 +838,24 @@ def retrieve_dream_command(opt,command,completer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
commands = []
|
commands = []
|
||||||
|
cmd = None
|
||||||
for path in paths:
|
for path in paths:
|
||||||
try:
|
try:
|
||||||
cmd = dream_cmd_from_png(path)
|
cmd = dream_cmd_from_png(path)
|
||||||
except OSError:
|
|
||||||
print(f'## {path}: file could not be read')
|
|
||||||
continue
|
|
||||||
except (KeyError, AttributeError, IndexError):
|
except (KeyError, AttributeError, IndexError):
|
||||||
print(f'## {path}: file has no metadata')
|
print(f'## {path}: file has no metadata')
|
||||||
continue
|
|
||||||
except:
|
except:
|
||||||
print(f'## {path}: file could not be processed')
|
print(f'## {path}: file could not be processed')
|
||||||
continue
|
if cmd:
|
||||||
|
commands.append(f'# {path}')
|
||||||
commands.append(f'# {path}')
|
commands.append(cmd)
|
||||||
commands.append(cmd)
|
if len(commands)>0:
|
||||||
|
dir,basename = os.path.split(outfilepath)
|
||||||
with open(outfilepath, 'w', encoding='utf-8') as f:
|
if len(dir)==0:
|
||||||
f.write('\n'.join(commands))
|
outfilepath = os.path.join(opt.outdir,basename)
|
||||||
print(f'>> File {outfilepath} with commands created')
|
with open(outfilepath, 'w', encoding='utf-8') as f:
|
||||||
|
f.write('\n'.join(commands))
|
||||||
if len(commands) == 2:
|
print(f'>> File {outfilepath} with commands created')
|
||||||
completer.set_line(commands[1])
|
|
||||||
|
|
||||||
######################################
|
######################################
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user