resolved conflicts with main

Lincoln Stein
2023-07-27 15:11:25 -04:00
275 changed files with 11706 additions and 8208 deletions

View File

@ -1,15 +1,6 @@
"""
Initialization file for invokeai.backend
"""
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Img2Img,
Inpaint
)
from .model_management import (
ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo
)
from .safety_checker import SafetyChecker
from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
from .model_management.models import SilenceWarnings
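
The collapsed imports above keep the same symbols exported from invokeai.backend, plus the newly re-exported SilenceWarnings. A minimal sketch of a downstream import against this package-level API (hypothetical call site, not part of the commit):

# Hypothetical downstream usage; the names match the exports shown above.
from invokeai.backend import (
    InvokeAIGenerator,
    InvokeAIGeneratorBasicParams,
    ModelManager,
    SilenceWarnings,
)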

View File

@ -28,68 +28,71 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
seed: Optional[int]=None
width: int=512
height: int=512
cfg_scale: float=7.5
steps: int=20
ddim_eta: float=0.0
scheduler: str='ddim'
precision: str='float16'
perlin: float=0.0
threshold: float=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: Optional[float]=None
seed: Optional[int] = None
width: int = 512
height: int = 512
cfg_scale: float = 7.5
steps: int = 20
ddim_eta: float = 0.0
scheduler: str = "ddim"
precision: str = "float16"
perlin: float = 0.0
threshold: float = 0.0
seamless: bool = False
seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
h_symmetry_time_pct: Optional[float] = None
v_symmetry_time_pct: Optional[float] = None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: Optional[SafetyChecker]=None
with_variations: list = field(default_factory=list)
@dataclass
class InvokeAIGeneratorOutput:
'''
"""
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
"""
image: Image.Image
seed: int
model_hash: str
attention_maps_images: List[Image.Image]
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
**kwargs,
):
self.model_info=model_info
self.params=params
def __init__(
self,
model_info: dict,
params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
**kwargs,
):
self.model_info = model_info
self.params = params
self.kwargs = kwargs
def generate(
self,
conditioning: tuple,
scheduler,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
callback: Optional[Callable] = None,
step_callback: Optional[Callable] = None,
iterations: int = 1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
'''
) -> Iterator[InvokeAIGeneratorOutput]:
"""
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
object. Use like this:
@ -109,7 +112,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for o in outputs:
print(o.image, o.seed)
'''
"""
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
@ -120,22 +123,21 @@ class InvokeAIGenerator(metaclass=ABCMeta):
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
generator.set_variation(
generator_args.get("seed"),
generator_args.get("variation_amount"),
generator_args.get("with_variations"),
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
configure_model_padding(
component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
configure_model_padding(
model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
@ -149,66 +151,66 @@ class InvokeAIGenerator(metaclass=ABCMeta):
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
model_hash=model_hash,
params=Namespace(model_name=model_name, **generator_args),
)
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
def schedulers(self) -> List[str]:
"""
Return list of all the schedulers that we currently handle.
'''
"""
return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
@classmethod
def _generator_class(cls)->Type[Generator]:
'''
def _generator_class(cls) -> Type[Generator]:
"""
In derived classes return the name of the generator to apply.
If you don't override this, it will return the name of the derived
class, which nicely parallels the generator class names.
'''
"""
return Generator
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
init_image: Union[Image.Image, torch.FloatTensor],
strength: float=0.75,
**keyword_args
)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image,
strength=strength,
**keyword_args
)
def generate(
self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
) -> Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image, strength=strength, **keyword_args)
@classmethod
def _generator_class(cls):
from .img2img import Img2Img
return Img2Img
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
def generate(self,
mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 30,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args
)->Iterator[InvokeAIGeneratorOutput]:
def generate(
self,
mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 30,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args,
) -> Iterator[InvokeAIGeneratorOutput]:
return super().generate(
mask_image=mask_image,
seam_size=seam_size,
@ -221,13 +223,16 @@ class Inpaint(Img2Img):
inpaint_width=inpaint_width,
inpaint_height=inpaint_height,
inpaint_fill=inpaint_fill,
**keyword_args
**keyword_args,
)
@classmethod
def _generator_class(cls):
from .inpaint import Inpaint
return Inpaint
class Generator:
downsampling_factor: int
latent_channels: int
@ -240,7 +245,6 @@ class Generator:
self.seed = None
self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
@ -254,9 +258,7 @@ class Generator:
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
"""
raise NotImplementedError(
"image_iterator() must be implemented in a descendent class"
)
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
def set_variation(self, seed, variation_amount, with_variations):
self.seed = seed
@ -277,17 +279,13 @@ class Generator:
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False,
**kwargs,
):
scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(
saver.get_stacked_maps_image()
)
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
sampler=sampler,
init_image=init_image,
@ -329,17 +327,10 @@ class Generator:
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
results.append([image, seed, attention_maps_images])
if image_callback is not None:
attention_maps_image = (
None
if len(attention_maps_images) == 0
else attention_maps_images[-1]
)
attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
image_callback(
image,
seed,
@ -350,9 +341,7 @@ class Generator:
seed = self.new_seed()
# Free up memory from the last generation.
clear_cuda_cache = (
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)
clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
if clear_cuda_cache is not None:
clear_cuda_cache()
@ -379,14 +368,8 @@ class Generator:
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = (
init_mask.getchannel("A")
if init_mask.mode == "RGBA"
else init_mask.convert("L")
)
pil_init_image = init_image.convert(
"RGBA"
) # Add an alpha channel if one doesn't exist
pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
@ -412,10 +395,7 @@ class Generator:
np_matched_result[:, :, :] = (
(
(
(
np_matched_result[:, :, :].astype(np.float32)
- gen_means[None, None, :]
)
(np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
/ gen_std[None, None, :]
)
* init_std[None, None, :]
@ -441,9 +421,7 @@ class Generator:
else:
blurred_init_mask = pil_init_mask
multiplied_blurred_init_mask = ImageChops.multiply(
blurred_init_mask, self.pil_image.split()[-1]
)
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
@ -469,10 +447,7 @@ class Generator:
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte()  # change scale from -1..1 to 0..1, then to 0..255
).cpu()
return Image.fromarray(latents_ubyte.numpy())
@ -502,9 +477,7 @@ class Generator:
temp_height = int((height + 7) / 8) * 8
noise = torch.stack(
[
rand_perlin_2d(
(temp_height, temp_width), (8, 8), device=self.model.device
).to(fixdevice)
rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
for _ in range(input_channels)
],
dim=0,
@ -581,8 +554,6 @@ class Generator:
device=device,
)
if self.perlin > 0.0:
perlin_noise = self.get_perlin_noise(
width // self.downsampling_factor, height // self.downsampling_factor
)
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x

View File

@ -77,10 +77,7 @@ class Img2Img(Generator):
callback=step_callback,
seed=seed,
)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
@ -91,7 +88,5 @@ class Img2Img(Generator):
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
shape[3], shape[2]
)
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
return x
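
The perlin branch above is a plain linear interpolation between gaussian latent noise and a perlin noise tensor of the same spatial size. A self-contained sketch of the same blend with a stand-in noise source (the real code uses Generator.get_perlin_noise):

import torch

def blend_noise(x: torch.Tensor, perlin_noise: torch.Tensor, perlin: float) -> torch.Tensor:
    # perlin == 0.0 keeps pure gaussian noise; perlin == 1.0 uses only the perlin component
    return (1 - perlin) * x + perlin * perlin_noise

x = torch.randn(1, 4, 64, 64)             # gaussian latent noise, NCHW
perlin_noise = torch.randn(1, 4, 64, 64)  # stand-in for get_perlin_noise(width, height)
blended = blend_noise(x, perlin_noise, perlin=0.2)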

View File

@ -68,15 +68,11 @@ class Inpaint(Img2Img):
return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image:
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
@ -127,15 +123,11 @@ class Inpaint(Img2Img):
return si
def mask_edge(
self, mask: Image.Image, edge_size: int, edge_blur: int
) -> Image.Image:
def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
npgradient = np.uint8(
255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
)
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
# Detect hard edges
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
@ -144,9 +136,7 @@ class Inpaint(Img2Img):
npmask = npgradient + npedge
# Expand
npmask = cv2.dilate(
npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
)
npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
new_mask = Image.fromarray(npmask)
@ -242,25 +232,19 @@ class Inpaint(Img2Img):
if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
init_filled = self.infill_patchmatch(self.pil_image.copy())
elif infill_method == "tile":
init_filled = self.tile_fill_missing(
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
elif infill_method == "solid":
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = Image.alpha_composite(solid_bg, init_image)
else:
raise ValueError(
f"Non-supported infill type {infill_method}", infill_method
)
raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
init_filled.paste(init_image, (0, 0), init_image.split()[-1])
# Resize if requested for inpainting
if inpaint_width and inpaint_height:
init_filled = init_filled.resize((inpaint_width, inpaint_height))
debug_image(
init_filled, "init_filled", debug_status=self.enable_image_debugging
)
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
# Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
@ -289,9 +273,7 @@ class Inpaint(Img2Img):
"mask_image AFTER multiply with pil_image",
debug_status=self.enable_image_debugging,
)
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
mask_image, normalize=False
)
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
else:
mask: torch.FloatTensor = mask_image
@ -302,9 +284,9 @@ class Inpaint(Img2Img):
# todo: support cross-attention control
uc, c, _ = conditioning
conditioning_data = ConditioningData(
uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
pipeline.scheduler, eta=ddim_eta
)
def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings(
@ -318,15 +300,10 @@ class Inpaint(Img2Img):
seed=seed,
)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
result = self.postprocess_size_and_mask(
pipeline.numpy_to_pil(pipeline_output.images)[0]
)
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:

View File

@ -8,9 +8,7 @@ from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid
def debug_image(
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
if not debug_status:
return

View File

@ -0,0 +1,34 @@
"""
This module defines a singleton object, "invisible_watermark", that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, which allows the watermarking to be suppressed.
"""
import numpy as np
import cv2
from PIL import Image
from imwatermark import WatermarkEncoder
from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
class InvisibleWatermark:
"""
Wrapper around InvisibleWatermark module.
"""
@classmethod
def invisible_watermark_available(self) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text: str) -> Image:
if not self.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
encoder = WatermarkEncoder()
encoder.set_watermark("bytes", watermark_text.encode("utf-8"))
bgr_encoded = encoder.encode(bgr, "dwtDct")
return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")
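
A short usage sketch for the wrapper above, assuming invisible_watermark is enabled in the app config; the import path and file names are assumptions, since the diff view does not show them:

from PIL import Image
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark  # path assumed

img = Image.open("example.png")                            # hypothetical input file
marked = InvisibleWatermark.add_watermark(img, "InvokeAI")
marked.save("example_watermarked.png")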

View File

@ -7,8 +7,10 @@ be suppressed or deferred
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.

View File

@ -34,9 +34,7 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(
self, image, dream_prompt, name, metadata=None, compress_level=6
):
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt)
@ -114,8 +112,6 @@ class PromptFormatter:
if opt.variation_amount > 0:
switches.append(f"-v{opt.variation_amount}")
if opt.with_variations:
formatted_variations = ",".join(
f"{seed}:{weight}" for seed, weight in opt.with_variations
)
formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
switches.append(f"-V{formatted_variations}")
return " ".join(switches)

View File

@ -0,0 +1,64 @@
"""
This module defines a singleton object, "safety_checker", that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, which allows the checker to be suppressed.
"""
import numpy as np
from PIL import Image
from invokeai.backend import SilenceWarnings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
safety_checker = None
feature_extractor = None
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
return
if config.nsfw_checker:
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
self.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]
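
A usage sketch for the checker above; it returns False when the checker model is not installed or nsfw_checker is disabled, so the call is safe either way. The import path is an assumption:

from PIL import Image
from invokeai.backend.image_util.safety_checker import SafetyChecker  # path assumed

image = Image.open("candidate.png")   # hypothetical input file
if SafetyChecker.has_nsfw_concept(image):
    print("Flagged by the NSFW checker")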

View File

@ -5,12 +5,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(
input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
)
working = nn.functional.pad(
working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
)
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
@ -32,18 +28,14 @@ def configure_model_padding(model, seamless, seamless_axes):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = (
"circular" if ("x" in seamless_axes) else "constant"
)
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = (
"circular" if ("y" in seamless_axes) else "constant"
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,

View File

@ -39,23 +39,18 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
self.heatmap = heatmap
self.image = image
def to_grayscale(self, invert: bool = False) -> Image:
return self._rescale(
Image.fromarray(
np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
)
)
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
def to_mask(self, threshold: float = 0.5) -> Image:
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(
Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
)
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
def to_transparent(self, invert: bool = False) -> Image:
transparent_image = self.image.copy()
@ -67,11 +62,7 @@ class SegmentedGrayscale(object):
# unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap: Image) -> Image:
size = (
self.image.width
if (self.image.width > self.image.height)
else self.image.height
)
size = self.image.width if (self.image.width > self.image.height) else self.image.height
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
return resized_image.crop((0, 0, self.image.width, self.image.height))
@ -87,12 +78,8 @@ class Txt2Mask(object):
# BUG: we are not doing anything with the device option at this time
self.device = device
self.processor = AutoProcessor.from_pretrained(
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
self.model = CLIPSegForImageSegmentation.from_pretrained(
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
@torch.no_grad()
def segment(self, image, prompt: str) -> SegmentedGrayscale:
@ -107,9 +94,7 @@ class Txt2Mask(object):
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image)
inputs = self.processor(
text=[prompt], images=[img], padding=True, return_tensors="pt"
)
inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
outputs = self.model(**inputs)
heatmap = torch.sigmoid(outputs.logits)
return SegmentedGrayscale(image, heatmap)
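
The segment()/SegmentedGrayscale pair above gives a simple prompt-to-mask flow. A hedged sketch, assuming the CLIPSeg weights can be fetched and using an arbitrary example prompt:

from PIL import Image

image = Image.open("photo.png")            # hypothetical input file
masker = Txt2Mask()                        # constructor defaults assumed
gray = masker.segment(image, "a dog")      # SegmentedGrayscale heatmap wrapper
mask = gray.to_mask(threshold=0.5)         # PIL mask, cropped back to the input size
mask.save("dog_mask.png")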

View File

@ -0,0 +1,36 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
except Exception as e:
print()
print(f"An exception has occurred: {str(e)}")
print("== STARTUP ABORTED ==")
print("** One or more necessary files is missing from your InvokeAI root directory **")
print("** Please rerun the configuration script to fix this problem. **")
print("** From the launcher, selection option [7]. **")
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)
input("Press any key to continue...")
sys.exit(0)
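
A sketch of the intended call site for the check above: run it early in startup, before any models are loaded. The module path is assumed, since the diff view omits file names:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.check_root import check_invokeai_root  # path assumed

config = InvokeAIAppConfig.get_config()
check_invokeai_root(config)   # prints a diagnostic and exits if required files are missing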

View File

@ -13,8 +13,8 @@ import os
import shutil
import textwrap
import traceback
import warnings
import yaml
import warnings
from argparse import Namespace
from pathlib import Path
from shutil import get_terminal_size
@ -32,6 +32,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
CLIPTextModel,
CLIPTextConfig,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
@ -44,6 +45,7 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
FileBox,
IntTitleSlider,
@ -58,9 +60,7 @@ from invokeai.backend.install.model_install_backend import (
InstallSelections,
ModelInstall,
)
from invokeai.backend.model_management.model_probe import (
ModelType, BaseModelType
)
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -75,7 +75,7 @@ Model_dir = "models"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ['auto','float16','float32']
PRECISION_CHOICES = ["auto", "float16", "float32"]
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
@ -83,7 +83,8 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""
logger=InvokeAILogger.getLogger()
logger = InvokeAILogger.getLogger()
# --------------------------------------------
def postscript(errors: None):
@ -106,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for
"""
else:
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
message = (
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
)
for err in errors:
message += f"\t - {err}\n"
message += "Please check the logs above and correct any issues."
@ -167,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
logger.info(f"Installing {label} model file {model_url}...")
if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
)
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
logger.info("...downloaded successfully")
else:
logger.info("...exists")
@ -180,81 +181,93 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models():
target_dir = config.root_path / 'models/core/convert'
target_dir = config.root_path / "models/core/convert"
kwargs = dict() # for future use
try:
logger.info('Downloading core tokenizers and text encoders')
logger.info("Downloading core tokenizers and text encoders")
# bert
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
repo_id = "openai/clip-vit-large-patch14"
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split("/")
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
logger.info("Downloading stable diffusion VAE")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
# safety checking
logger.info('Downloading safety checker')
logger.info("Downloading safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
# ---------------------------------------------
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description = "RealESRGAN_x4plus.pth",
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth",
),
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description = "RealESRGAN_x4plus_anime_6B.pth",
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth",
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description = "RealESRGAN_x2plus.pth",
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth",
),
]
for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
# ---------------------------------------------
def download_support_models():
download_realesrgan()
download_conversion_models()
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
@ -264,6 +277,7 @@ def get_root(root: str = None) -> str:
else:
return str(config.root_path)
# -------------------------------------
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
# for responsive resizing - disabled
@ -272,14 +286,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
first_time = not (config.root_path / 'invokeai.yaml').exists()
first_time = not (config.root_path / "invokeai.yaml").exists()
access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size()
label = """Configure startup settings. You can come back and change these later.
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
Use cursor arrows to make a checkbox selection, and space to toggle.
"""
for i in textwrap.wrap(label,width=window_width-6):
for i in textwrap.wrap(label, width=window_width - 6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@ -287,50 +301,9 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
color="CONTROL",
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== BASIC OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Select an output directory for images:",
editable=False,
color="CONTROL",
)
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Activate the NSFW checker to blur images showing potential sexual imagery:",
editable=False,
color="CONTROL",
)
self.nsfw_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.nsfw_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
for line in textwrap.wrap(label,width=window_width-6):
for line in textwrap.wrap(label, width=window_width - 6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=line,
@ -347,15 +320,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== ADVANCED OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="GPU Management",
@ -369,34 +333,47 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
npyscreen.Checkbox,
name="Free GPU memory after each generation",
value=old_opts.free_gpu_mem,
max_width=45,
relx=5,
scroll_exit=True,
)
self.nextrely -= 1
self.xformers_enabled = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Enable xformers support if available",
name="Enable xformers support",
value=old_opts.xformers_enabled,
relx=5,
max_width=30,
relx=50,
scroll_exit=True,
)
self.nextrely -= 1
self.always_use_cpu = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force CPU to be used on GPU systems",
value=old_opts.always_use_cpu,
relx=5,
relx=80,
scroll_exit=True,
)
precision = old_opts.precision or (
"float32" if program_opts.full_precision else "auto"
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Floating Point Precision",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.precision = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
columns = 2,
SingleSelectColumns,
columns=3,
name="Precision",
values=PRECISION_CHOICES,
value=PRECISION_CHOICES.index(precision),
begin_entry_at=3,
max_height=len(PRECISION_CHOICES) + 1,
max_height=2,
max_width=80,
scroll_exit=True,
)
self.max_cache_size = self.add_widget_intelligent(
@ -409,40 +386,38 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models (<tab> autocompletes, ctrl-N advances):",
editable=False,
color="CONTROL",
)
self.autoimport_dirs = {}
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
FileBox,
name=f'Autoimport Folder',
value=str(config.root_path / config.autoimport_dir),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
max_height = 3,
scroll_exit=True
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== LICENSE ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
self.outdir = self.add_widget_intelligent(
FileBox,
name="Output directory for images (<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
max_height=3,
scroll_exit=True,
)
self.nextrely -= 1
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
max_height=3,
scroll_exit=True,
)
self.nextrely += 1
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
https://huggingface.co/spaces/CompVis/stable-diffusion-license
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
"""
for i in textwrap.wrap(label,width=window_width-6):
for i in textwrap.wrap(label, width=window_width - 6):
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@ -451,22 +426,17 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
)
self.license_acceptance = self.add_widget_intelligent(
npyscreen.Checkbox,
name="I accept the CreativeML Responsible AI License",
name="I accept the CreativeML Responsible AI Licenses",
value=not first_time,
relx=2,
scroll_exit=True,
)
self.nextrely += 1
label = (
"DONE"
if program_opts.skip_sd_weights or program_opts.default_only
else "NEXT"
)
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
self.ok_button = self.add_widget_intelligent(
CenteredButtonPress,
name=label,
relx=(window_width - len(label)) // 2,
rely=-3,
when_pressed_function=self.on_ok,
)
@ -481,13 +451,11 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
self.editing = False
else:
self.editing = True
def validate_field_values(self, opt: Namespace) -> bool:
bad_fields = []
if not opt.license_acceptance:
bad_fields.append(
"Please accept the license terms before proceeding to model downloads"
)
bad_fields.append("Please accept the license terms before proceeding to model downloads")
if not Path(opt.outdir).parent.exists():
bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
@ -505,12 +473,11 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
new_opts = Namespace()
for attr in [
"outdir",
"nsfw_checker",
"free_gpu_mem",
"max_cache_size",
"xformers_enabled",
"always_use_cpu",
"outdir",
"free_gpu_mem",
"max_cache_size",
"xformers_enabled",
"always_use_cpu",
]:
setattr(new_opts, attr, getattr(self, attr).value)
@ -523,7 +490,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
return new_opts
@ -542,7 +509,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
"MAIN",
editOptsForm,
name="InvokeAI Startup Options",
cycle_widgets=True,
cycle_widgets=False,
)
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm(
@ -550,7 +517,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
addModelsForm,
name="Install Stable Diffusion Models",
multipage=True,
cycle_widgets=True,
cycle_widgets=False,
)
def new_opts(self):
@ -562,21 +529,20 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
editApp.run()
return editApp.new_opts()
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config()
if not init_file.exists():
opts.nsfw_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> InstallSelections:
try:
installer = ModelInstall(config)
except omegaconf.errors.ConfigKeyError:
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing')
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
initialize_rootdir(config.root_path, True)
installer = ModelInstall(config)
models = installer.all_models()
return InstallSelections(
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
@ -586,44 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
else list(),
)
# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in (
"models",
"databases",
"text-inversion-output",
"text-inversion-training-data",
"configs"
):
logger.info("Initializing InvokeAI runtime directory")
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
os.makedirs(os.path.join(root, name), exist_ok=True)
for model_type in ModelType:
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
configs_src = Path(configs.__path__[0])
configs_dest = root / "configs"
if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / 'models'
dest = root / "models"
for model_base in BaseModelType:
for model_type in ModelType:
path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = dest / 'core'
path = dest / "core"
path.mkdir(parents=True, exist_ok=True)
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
)
)
maybe_create_models_yaml(root)
def maybe_create_models_yaml(root: Path):
models_yaml = root / "configs" / "models.yaml"
if models_yaml.exists():
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
return
else:
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
with open(models_yaml, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
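
For reference, the stub that maybe_create_models_yaml() writes is just the metadata stanza; a quick way to see the exact text it produces (output shown for a recent PyYAML):

import yaml

print(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
# __metadata__:
#   version: 3.0.0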
# -------------------------------------
def run_console_ui(
program_opts: Namespace, initfile: Path = None
) -> (Namespace, Namespace):
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
@ -635,8 +603,9 @@ def run_console_ui(
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
import torch
torch.multiprocessing.set_start_method("spawn")
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
if editApp.user_cancelled:
@ -653,81 +622,86 @@ def write_opts(opts: Namespace, init_file: Path):
# this will load current settings
new_config = InvokeAIAppConfig.get_config()
new_config.root = config.root
for key,value in opts.__dict__.items():
if hasattr(new_config,key):
setattr(new_config,key,value)
with open(init_file,'w', encoding='utf-8') as file:
for key, value in opts.__dict__.items():
if hasattr(new_config, key):
setattr(new_config, key, value)
with open(init_file, "w", encoding="utf-8") as file:
file.write(new_config.to_yaml())
if hasattr(opts,'hf_token') and opts.hf_token:
if hasattr(opts, "hf_token") and opts.hf_token:
HfLogin(opts.hf_token)
# -------------------------------------
def default_output_dir() -> Path:
return config.root_path / "outputs"
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile)
write_opts(opt, initfile)
# -------------------------------------
# Here we bring in
# the legacy Args object in order to parse
# the old init file and write out the new
# yaml format.
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
if hasattr(old,attr):
setattr(new,attr,getattr(old,attr))
if hasattr(old, attr):
setattr(new, attr, getattr(old, attr))
# a few places where the field names have changed and we have to
# manually add in the new names/values
new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers
new.conf_path = old.conf
new.root = legacy_format.parent.resolve()
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
invokeai_yaml = legacy_format.parent / "invokeai.yaml"
with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
outfile.write(new.to_yaml())
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
# -------------------------------------
def migrate_models(root: Path):
from invokeai.backend.install.migrate_to_3 import do_migrate
do_migrate(root, root)
def migrate_if_needed(opt: Namespace, root: Path)->bool:
# We check to see if the runtime directory is correctly initialized.
old_init_file = root / 'invokeai.init'
new_init_file = root / 'invokeai.yaml'
old_hub = root / 'models/hub'
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
if migration_needed:
if opt.yes_to_all or \
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
logger.info('** Migrating invokeai.init to invokeai.yaml')
def migrate_if_needed(opt: Namespace, root: Path) -> bool:
# We check to see if the runtime directory is correctly initialized.
old_init_file = root / "invokeai.init"
new_init_file = root / "invokeai.yaml"
old_hub = root / "models/hub"
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
if migration_needed:
if opt.yes_to_all or yes_or_no(
f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
):
logger.info("** Migrating invokeai.init to invokeai.yaml")
migrate_init_file(old_init_file)
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
if old_hub.exists():
migrate_models(config.root_path)
else:
print('Cannot continue without conversion. Aborting.')
print("Cannot continue without conversion. Aborting.")
return migration_needed
# -------------------------------------
def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
@ -784,9 +758,9 @@ def main():
invoke_args = []
if opt.root:
invoke_args.extend(['--root',opt.root])
invoke_args.extend(["--root", opt.root])
if opt.full_precision:
invoke_args.extend(['--precision','float32'])
invoke_args.extend(["--precision", "float32"])
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
@ -798,41 +772,36 @@ def main():
if migrate_if_needed(opt, config.root_path):
sys.exit(0)
if not config.model_conf_path.exists():
initialize_rootdir(config.root_path, opt.yes_to_all)
# run this unconditionally in case new directories need to be added
initialize_rootdir(config.root_path, opt.yes_to_all)
models_to_download = default_user_selections(opt)
new_init_file = config.root_path / 'invokeai.yaml'
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(
precision="float32" if opt.full_precision else "float16"
)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
else:
init_options, models_to_download = run_console_ui(opt, new_init_file)
if init_options:
write_opts(init_options, new_init_file)
else:
logger.info(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
)
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
sys.exit(0)
if opt.skip_support_models:
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
logger.info("Skipping support models at user's request")
else:
logger.info("CHECKING/UPDATING SUPPORT MODELS")
logger.info("Installing support models")
download_support_models()
if opt.skip_sd_weights:
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
logger.warning("Skipping diffusion weights download per user request")
elif models_to_download:
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download)
postscript(errors=errors)
if not opt.yes_to_all:
input('Press any key to continue...')
input("Press any key to continue...")
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@ -47,17 +47,18 @@ PRECISION_CHOICES = [
"float16",
]
class FileArgumentParser(ArgumentParser):
"""
Supports reading defaults from an init file.
"""
def convert_arg_line_to_args(self, arg_line):
return shlex.split(arg_line, comments=True)
legacy_parser = FileArgumentParser(
description=
"""
description="""
Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
@ -65,304 +66,279 @@ Generate images using Stable Diffusion.
Other command-line arguments are defaults that can usually be overridden
at the command prompt.
""",
fromfile_prefix_chars='@',
fromfile_prefix_chars="@",
)
general_group = legacy_parser.add_argument_group('General')
model_group = legacy_parser.add_argument_group('Model selection')
file_group = legacy_parser.add_argument_group('Input/output')
web_server_group = legacy_parser.add_argument_group('Web server')
render_group = legacy_parser.add_argument_group('Rendering')
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
general_group = legacy_parser.add_argument_group("General")
model_group = legacy_parser.add_argument_group("Model selection")
file_group = legacy_parser.add_argument_group("Input/output")
web_server_group = legacy_parser.add_argument_group("Web server")
render_group = legacy_parser.add_argument_group("Rendering")
postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
deprecated_group = legacy_parser.add_argument_group("Deprecated options")
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
general_group.add_argument(
'--version','-V',
action='store_true',
help='Print InvokeAI version number'
)
deprecated_group.add_argument("--laion400m")
deprecated_group.add_argument("--weights") # deprecated
general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
model_group.add_argument(
'--root_dir',
"--root_dir",
default=None,
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
)
model_group.add_argument(
'--config',
'-c',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
"--config",
"-c",
"-config",
dest="conf",
default="./configs/models.yaml",
help="Path to configuration file for alternate models.",
)
model_group.add_argument(
'--model',
"--model",
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--weight_dirs',
nargs='+',
"--weight_dirs",
nargs="+",
type=str,
help='List of one or more directories that will be auto-scanned for new model weights to import',
help="List of one or more directories that will be auto-scanned for new model weights to import",
)
model_group.add_argument(
'--png_compression','-z',
"--png_compression",
"-z",
type=int,
default=6,
choices=range(0,9),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
choices=range(0, 9),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
)
model_group.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Deprecated way to set --precision=float32',
"-F",
"--full_precision",
dest="full_precision",
action="store_true",
help="Deprecated way to set --precision=float32",
)
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
"--max_loaded_models",
dest="max_loaded_models",
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
)
model_group.add_argument(
'--free_gpu_mem',
dest='free_gpu_mem',
action='store_true',
help='Force free gpu memory before final decoding',
"--free_gpu_mem",
dest="free_gpu_mem",
action="store_true",
help="Force free gpu memory before final decoding",
)
model_group.add_argument(
'--sequential_guidance',
dest='sequential_guidance',
action='store_true',
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
"at the expense of speed",
"--sequential_guidance",
dest="sequential_guidance",
action="store_true",
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
)
model_group.add_argument(
'--xformers',
"--xformers",
action=argparse.BooleanOptionalAction,
default=True,
help='Enable/disable xformers support (default enabled if installed)',
help="Enable/disable xformers support (default enabled if installed)",
)
model_group.add_argument(
"--always_use_cpu",
dest="always_use_cpu",
action="store_true",
help="Force use of CPU even if GPU is available"
"--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
)
model_group.add_argument(
'--precision',
dest='precision',
"--precision",
dest="precision",
type=str,
choices=PRECISION_CHOICES,
metavar='PRECISION',
metavar="PRECISION",
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
default="auto",
)
model_group.add_argument(
'--ckpt_convert',
"--ckpt_convert",
action=argparse.BooleanOptionalAction,
dest='ckpt_convert',
dest="ckpt_convert",
default=True,
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
)
model_group.add_argument(
'--internet',
"--internet",
action=argparse.BooleanOptionalAction,
dest='internet_available',
dest="internet_available",
default=True,
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
)
model_group.add_argument(
'--nsfw_checker',
'--safety_checker',
"--nsfw_checker",
"--safety_checker",
action=argparse.BooleanOptionalAction,
dest='safety_checker',
dest="safety_checker",
default=False,
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
)
model_group.add_argument(
'--autoimport',
"--autoimport",
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
)
model_group.add_argument(
'--autoconvert',
"--autoconvert",
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
)
model_group.add_argument(
'--patchmatch',
"--patchmatch",
action=argparse.BooleanOptionalAction,
default=True,
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
)
file_group.add_argument(
'--from_file',
dest='infile',
"--from_file",
dest="infile",
type=str,
help='If specified, load prompts from this file',
help="If specified, load prompts from this file",
)
file_group.add_argument(
'--outdir',
'-o',
"--outdir",
"-o",
type=str,
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
default='outputs',
help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
default="outputs",
)
file_group.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
"--prompt_as_dir",
"-p",
action="store_true",
help="Place images in subdirectories named after the prompt.",
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
"--fnformat",
default="{prefix}.{seed}.png",
type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
)
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
render_group.add_argument(
'-s',
'--steps',
"-W",
"--width",
type=int,
default=50,
help='Number of steps'
help="Image width, multiple of 64",
)
render_group.add_argument(
'-W',
'--width',
"-H",
"--height",
type=int,
help='Image width, multiple of 64',
help="Image height, multiple of 64",
)
render_group.add_argument(
'-H',
'--height',
type=int,
help='Image height, multiple of 64',
)
render_group.add_argument(
'-C',
'--cfg_scale',
"-C",
"--cfg_scale",
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
render_group.add_argument(
'--sampler',
'-A',
'-m',
dest='sampler_name',
"--sampler",
"-A",
"-m",
dest="sampler_name",
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
metavar="SAMPLER_NAME",
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
default="k_lms",
)
render_group.add_argument(
'--log_tokenization',
'-t',
action='store_true',
help='shows how the prompt is split into tokens'
"--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
)
render_group.add_argument(
'-f',
'--strength',
"-f",
"--strength",
type=float,
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
)
render_group.add_argument(
'-T',
'-fit',
'--fit',
"-T",
"-fit",
"--fit",
action=argparse.BooleanOptionalAction,
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
)
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
render_group.add_argument(
'--grid',
'-g',
action=argparse.BooleanOptionalAction,
help='generate a grid'
)
render_group.add_argument(
'--embedding_directory',
'--embedding_path',
dest='embedding_path',
default='embeddings',
"--embedding_directory",
"--embedding_path",
dest="embedding_path",
default="embeddings",
type=str,
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
)
render_group.add_argument(
'--lora_directory',
dest='lora_path',
default='loras',
"--lora_directory",
dest="lora_path",
default="loras",
type=str,
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
)
render_group.add_argument(
'--embeddings',
"--embeddings",
action=argparse.BooleanOptionalAction,
default=True,
help='Enable embedding directory (default). Use --no-embeddings to disable.',
help="Enable embedding directory (default). Use --no-embeddings to disable.",
)
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
render_group.add_argument(
'--enable_image_debugging',
action='store_true',
help='Generates debugging image to display'
)
render_group.add_argument(
'--karras_max',
"--karras_max",
type=int,
default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
)
# Restoration related args
postprocessing_group.add_argument(
'--no_restore',
dest='restore',
action='store_false',
help='Disable face restoration with GFPGAN or codeformer',
"--no_restore",
dest="restore",
action="store_false",
help="Disable face restoration with GFPGAN or codeformer",
)
postprocessing_group.add_argument(
'--no_upscale',
dest='esrgan',
action='store_false',
help='Disable upscaling with ESRGAN',
"--no_upscale",
dest="esrgan",
action="store_false",
help="Disable upscaling with ESRGAN",
)
postprocessing_group.add_argument(
'--esrgan_bg_tile',
"--esrgan_bg_tile",
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
)
postprocessing_group.add_argument(
'--esrgan_denoise_str',
"--esrgan_denoise_str",
type=float,
default=0.75,
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
)
postprocessing_group.add_argument(
'--gfpgan_model_path',
"--gfpgan_model_path",
type=str,
default='./models/gfpgan/GFPGANv1.4.pth',
help='Indicates the path to the GFPGAN model',
default="./models/gfpgan/GFPGANv1.4.pth",
help="Indicates the path to the GFPGAN model",
)
web_server_group.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
"--web",
dest="web",
action="store_true",
help="Start in web server mode.",
)
web_server_group.add_argument(
'--web_develop',
dest='web_develop',
action='store_true',
help='Start in web server development mode.',
"--web_develop",
dest="web_develop",
action="store_true",
help="Start in web server development mode.",
)
web_server_group.add_argument(
"--web_verbose",
@ -376,32 +352,27 @@ web_server_group.add_argument(
help="Additional allowed origins, comma-separated",
)
web_server_group.add_argument(
'--host',
"--host",
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
default="127.0.0.1",
help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
)
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
web_server_group.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
web_server_group.add_argument(
'--certfile',
"--certfile",
type=str,
default=None,
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
)
web_server_group.add_argument(
'--keyfile',
"--keyfile",
type=str,
default=None,
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
help="Web server: Path to private key file to use for SSL. Use together with --certfile",
)
web_server_group.add_argument(
'--gui',
dest='gui',
action='store_true',
help='Start InvokeAI GUI',
"--gui",
dest="gui",
action="store_true",
help="Start InvokeAI GUI",
)
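A hedged usage sketch (not part of the diff above; the init-file name and option values are illustrative only): the FileArgumentParser defined earlier expands "@<path>" arguments into the options stored in that file, one option per line, applying shlex-style quoting and stripping "#" comments.
# invokeai.init (illustrative contents):
#   --outdir=outputs
#   --steps=30   # comments are stripped by shlex.split(..., comments=True)
opt = legacy_parser.parse_args(["@invokeai.init", "--steps", "50"])
print(opt.outdir, opt.steps)  # -> outputs 50  (explicit flags override init-file defaults)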

View File

@ -1,7 +1,7 @@
'''
"""
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
'''
"""
import os
import argparse
@ -29,14 +29,13 @@ from transformers import (
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
)
from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
# holder for paths that we will migrate
@dataclass
class ModelPaths:
@ -45,81 +44,82 @@ class ModelPaths:
loras: Path
controlnets: Path
class MigrateTo3(object):
def __init__(self,
from_root: Path,
to_models: Path,
model_manager: ModelManager,
src_paths: ModelPaths,
):
def __init__(
self,
from_root: Path,
to_models: Path,
model_manager: ModelManager,
src_paths: ModelPaths,
):
self.root_directory = from_root
self.dest_models = to_models
self.mgr = model_manager
self.src_paths = src_paths
@classmethod
def initialize_yaml(cls, yaml_file: Path):
with open(yaml_file, 'w') as file:
file.write(
yaml.dump(
{
'__metadata__': {'version':'3.0.0'}
}
)
)
with open(yaml_file, "w") as file:
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def create_directory_structure(self):
'''
"""
Create the basic directory structure for the models folder.
'''
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
ModelType.ControlNet,ModelType.TextualInversion]:
"""
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
for model_type in [
ModelType.Main,
ModelType.Vae,
ModelType.Lora,
ModelType.ControlNet,
ModelType.TextualInversion,
]:
path = self.dest_models / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = self.dest_models / 'core'
path = self.dest_models / "core"
path.mkdir(parents=True, exist_ok=True)
@staticmethod
def copy_file(src:Path,dest:Path):
'''
def copy_file(src: Path, dest: Path):
"""
copy a single file with logging
'''
"""
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f'Copying {str(src)} to {str(dest)}')
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copy(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
logger.error(f"COPY FAILED: {str(e)}")
@staticmethod
def copy_dir(src:Path,dest:Path):
'''
def copy_dir(src: Path, dest: Path):
"""
Recursively copy a directory with logging
'''
"""
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f'Copying {str(src)} to {str(dest)}')
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copytree(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
logger.error(f"COPY FAILED: {str(e)}")
def migrate_models(self, src_dir: Path):
'''
"""
Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
'''
"""
directories_scanned = set()
for root, dirs, files in os.walk(src_dir):
for d in dirs:
try:
model = Path(root,d)
model = Path(root, d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
@ -136,9 +136,9 @@ class MigrateTo3(object):
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
try:
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
continue
model = Path(root,f)
model = Path(root, f)
if model.parent in directories_scanned:
continue
info = ModelProbe().heuristic_probe(model)
@ -154,148 +154,146 @@ class MigrateTo3(object):
logger.error(str(e))
def migrate_support_models(self):
'''
"""
Copy the clipseg, upscaler, and restoration models to their new
locations.
'''
"""
dest_directory = self.dest_models
if (self.root_directory / 'models/clipseg').exists():
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
if (self.root_directory / 'models/realesrgan').exists():
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
for d in ['codeformer','gfpgan']:
path = self.root_directory / 'models' / d
if (self.root_directory / "models/clipseg").exists():
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
if (self.root_directory / "models/realesrgan").exists():
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
for d in ["codeformer", "gfpgan"]:
path = self.root_directory / "models" / d
if path.exists():
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
def migrate_tuning_models(self):
'''
"""
Migrate the embeddings, loras and controlnets directories to their new homes.
'''
"""
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
if not src:
continue
if src.is_dir():
logger.info(f'Scanning {src}')
logger.info(f"Scanning {src}")
self.migrate_models(src)
else:
logger.info(f'{src} directory not found; skipping')
logger.info(f"{src} directory not found; skipping")
continue
def migrate_conversion_models(self):
'''
"""
Migrate all the models that are needed by the ckpt_to_diffusers conversion
script.
'''
"""
dest_directory = self.dest_models
kwargs = dict(
cache_dir = self.root_directory / 'models/hub',
#local_files_only = True
cache_dir=self.root_directory / "models/hub",
# local_files_only = True
)
try:
logger.info('Migrating core tokenizers and text encoders')
target_dir = dest_directory / 'core' / 'convert'
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
self._migrate_pretrained(BertTokenizerFast,
repo_id='bert-base-uncased',
dest = target_dir / 'bert-base-uncased',
**kwargs)
self._migrate_pretrained(
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14',
**kwargs)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14',
force = True,
**kwargs)
repo_id = "openai/clip-vit-large-patch14"
self._migrate_pretrained(
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
)
self._migrate_pretrained(
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
self._migrate_pretrained(CLIPTokenizer,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
**{'subfolder':'tokenizer',**kwargs}
)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
**{'subfolder':'text_encoder',**kwargs}
)
self._migrate_pretrained(
CLIPTokenizer,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
**{"subfolder": "tokenizer", **kwargs},
)
self._migrate_pretrained(
CLIPTextModel,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
**{"subfolder": "text_encoder", **kwargs},
)
# VAE
logger.info('Migrating stable diffusion VAE')
self._migrate_pretrained(AutoencoderKL,
repo_id = 'stabilityai/sd-vae-ft-mse',
dest = target_dir / 'sd-vae-ft-mse',
**kwargs)
logger.info("Migrating stable diffusion VAE")
self._migrate_pretrained(
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
)
# safety checking
logger.info('Migrating safety checker')
logger.info("Migrating safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
self._migrate_pretrained(AutoFeatureExtractor,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
self._migrate_pretrained(StableDiffusionSafetyChecker,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
self._migrate_pretrained(
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
)
self._migrate_pretrained(
StableDiffusionSafetyChecker,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-safety-checker",
**kwargs,
)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
if dest.exists() and not force:
logger.info(f'Skipping existing {dest}')
logger.info(f"Skipping existing {dest}")
return
model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
model_name = dest.name
if overwrite:
model.save_pretrained(dest, safe_serialization=True)
else:
download_path = dest.with_name(f'{model_name}.downloading')
download_path = dest.with_name(f"{model_name}.downloading")
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split('/')
_, model_name = repo_id.split("/")
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _vae_path(self, vae: Union[str,dict])->Path:
'''
def _vae_path(self, vae: Union[str, dict]) -> Path:
"""
Convert 2.3 VAE stanza to a straight path.
'''
"""
vae_path = None
# First get a path
if isinstance(vae,str):
if isinstance(vae, str):
vae_path = vae
elif isinstance(vae,DictConfig):
if p := vae.get('path'):
elif isinstance(vae, DictConfig):
if p := vae.get("path"):
vae_path = p
elif repo_id := vae.get('repo_id'):
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
vae_path = 'models/core/convert/sd-vae-ft-mse'
elif repo_id := vae.get("repo_id"):
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
vae_path = "models/core/convert/sd-vae-ft-mse"
return vae_path
else:
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
assert vae_path is not None, "Couldn't find VAE for this model"
@ -307,152 +305,144 @@ class MigrateTo3(object):
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
if vae_path.is_dir():
self.copy_dir(vae_path,dest)
self.copy_dir(vae_path, dest)
else:
self.copy_file(vae_path,dest)
self.copy_file(vae_path, dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path('models',rel_path)
return Path("models", rel_path)
else:
return vae_path
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
'''
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
"""
Migrate a locally-cached diffusers pipeline identified with a repo_id
'''
"""
dest_dir = self.dest_models
cache = self.root_directory / 'models/hub'
cache = self.root_directory / "models/hub"
kwargs = dict(
cache_dir = cache,
safety_checker = None,
cache_dir=cache,
safety_checker=None,
# local_files_only = True,
)
owner,repo_name = repo_id.split('/')
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name
model = cache / '--'.join(['models',owner,repo_name])
if len(list(model.glob('snapshots/**/model_index.json')))==0:
model = cache / "--".join(["models", owner, repo_name])
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
return
revisions = [x.name for x in model.glob('refs/*')]
revisions = [x.name for x in model.glob("refs/*")]
# if an fp16 is available we use that
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained(
repo_id,
revision=revision,
**kwargs)
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
dest = self._model_probe_to_path(info) / model_name
self._save_pretrained(pipeline, dest)
rel_path = Path('models',dest.relative_to(dest_dir))
rel_path = Path("models", dest.relative_to(dest_dir))
self._add_model(model_name, info, rel_path, **extra_config)
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
'''
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
"""
Migrate a model referred to using 'weights' or 'path'
'''
"""
# handle relative paths
dest_dir = self.dest_models
location = self.root_directory / location
model_name = model_name or location.stem
info = ModelProbe().heuristic_probe(location)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
# uh oh, weights is in the old models directory - move it into the new one
if Path(location).is_relative_to(self.src_paths.models):
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
if location.is_dir():
self.copy_dir(location,dest)
self.copy_dir(location, dest)
else:
self.copy_file(location,dest)
location = Path('models', info.base_type.value, info.model_type.value, location.name)
self.copy_file(location, dest)
location = Path("models", info.base_type.value, info.model_type.value, location.name)
self._add_model(model_name, info, location, **extra_config)
def _add_model(self,
model_name: str,
info: ModelProbeInfo,
location: Path,
**extra_config):
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
if info.model_type != ModelType.Main:
return
self.mgr.add_model(
model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
clobber = True,
model_attributes = {
'path': str(location),
'description': f'A {info.base_type.value} {info.model_type.value} model',
'model_format': info.format,
'variant': info.variant_type.value,
**extra_config,
}
)
def migrate_defined_models(self):
'''
Migrate models defined in models.yaml
'''
# find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
for model_name, stanza in conf.items():
self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
clobber=True,
model_attributes={
"path": str(location),
"description": f"A {info.base_type.value} {info.model_type.value} model",
"model_format": info.format,
"variant": info.variant_type.value,
**extra_config,
},
)
def migrate_defined_models(self):
"""
Migrate models defined in models.yaml
"""
# find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
for model_name, stanza in conf.items():
try:
passthru_args = {}
if vae := stanza.get('vae'):
if vae := stanza.get("vae"):
try:
passthru_args['vae'] = str(self._vae_path(vae))
passthru_args["vae"] = str(self._vae_path(vae))
except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e))
if config := stanza.get('config'):
passthru_args['config'] = config
if config := stanza.get("config"):
passthru_args["config"] = config
if description:= stanza.get('description'):
passthru_args['description'] = description
if repo_id := stanza.get('repo_id'):
logger.info(f'Migrating diffusers model {model_name}')
if description := stanza.get("description"):
passthru_args["description"] = description
if repo_id := stanza.get("repo_id"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get('weights'):
logger.info(f'Migrating checkpoint model {model_name}')
elif location := stanza.get("weights"):
logger.info(f"Migrating checkpoint model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get('path'):
logger.info(f'Migrating diffusers model {model_name}')
elif location := stanza.get("path"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate(self):
self.create_directory_structure()
# the configure script is doing this
@ -461,67 +451,71 @@ class MigrateTo3(object):
self.migrate_tuning_models()
self.migrate_defined_models()
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
'''
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
'''
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
"""
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
'--embedding_directory',
'--embedding_path',
"--embedding_directory",
"--embedding_path",
type=Path,
dest='embedding_path',
default=Path('embeddings'),
dest="embedding_path",
default=Path("embeddings"),
)
parser.add_argument(
'--lora_directory',
dest='lora_path',
"--lora_directory",
dest="lora_path",
type=Path,
default=Path('loras'),
default=Path("loras"),
)
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
return ModelPaths(
models = root / 'models',
embeddings = root / str(opt.embedding_path).strip('"'),
loras = root / str(opt.lora_path).strip('"'),
controlnets = root / 'controlnets',
models=root / "models",
embeddings=root / str(opt.embedding_path).strip('"'),
loras=root / str(opt.lora_path).strip('"'),
controlnets=root / "controlnets",
)
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
'''
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
'''
"""
# Don't use the config object because it is unforgiving of version updates
# Just use omegaconf directly
opt = OmegaConf.load(initfile)
paths = opt.InvokeAI.Paths
models = paths.get('models_dir','models')
embeddings = paths.get('embedding_dir','embeddings')
loras = paths.get('lora_dir','loras')
controlnets = paths.get('controlnet_dir','controlnets')
models = paths.get("models_dir", "models")
embeddings = paths.get("embedding_dir", "embeddings")
loras = paths.get("lora_dir", "loras")
controlnets = paths.get("controlnet_dir", "controlnets")
return ModelPaths(
models = root / models,
embeddings = root / embeddings,
loras = root /loras,
controlnets = root / controlnets,
models=root / models,
embeddings=root / embeddings,
loras=root / loras,
controlnets=root / controlnets,
)
def get_legacy_embeddings(root: Path) -> ModelPaths:
path = root / 'invokeai.init'
path = root / "invokeai.init"
if path.exists():
return _parse_legacy_initfile(root, path)
path = root / 'invokeai.yaml'
path = root / "invokeai.yaml"
if path.exists():
return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path):
"""
Migrate models from src to dest InvokeAI root directories
"""
config_file = dest_directory / 'configs' / 'models.yaml.3'
dest_models = dest_directory / 'models.3'
version_3 = (dest_directory / 'models' / 'core').exists()
config_file = dest_directory / "configs" / "models.yaml.3"
dest_models = dest_directory / "models.3"
version_3 = (dest_directory / "models" / "core").exists()
# Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the
@ -530,80 +524,80 @@ def do_migrate(src_directory: Path, dest_directory: Path):
# create a new empty one.
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / 'models').replace(dest_models)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models)
else:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file)
paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3(
from_root = src_directory,
to_models = dest_models,
model_manager = mgr,
src_paths = paths
)
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
migrator.migrate()
print("Migration successful.")
if not version_3:
(dest_directory / 'models').replace(src_directory / 'models.orig')
print(f'Original models directory moved to {dest_directory}/models.orig')
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
config_file.replace(config_file.with_suffix(''))
dest_models.replace(dest_models.with_suffix(''))
(dest_directory / "models").replace(src_directory / "models.orig")
print(f"Original models directory moved to {dest_directory}/models.orig")
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
config_file.replace(config_file.with_suffix(""))
dest_models.replace(dest_models.with_suffix(""))
def main():
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
description="""
parser = argparse.ArgumentParser(
prog="invokeai-migrate3",
description="""
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
script, which will perform a full upgrade in place."""
)
parser.add_argument('--from-directory',
dest='src_root',
type=Path,
required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
)
parser.add_argument('--to-directory',
dest='dest_root',
type=Path,
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
)
script, which will perform a full upgrade in place.""",
)
parser.add_argument(
"--from-directory",
dest="src_root",
type=Path,
required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
)
parser.add_argument(
"--to-directory",
dest="dest_root",
type=Path,
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
)
args = parser.parse_args()
src_root = args.src_root
assert src_root.is_dir(), f"{src_root} is not a valid directory"
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / "invokeai.init").exists() or (
src_root / "invokeai.yaml"
).exists(), f"{src_root} does not contain an InvokeAI init file."
dest_root = args.dest_root
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)])
config.parse_args(["--root", str(dest_root)])
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root,dest_root)
do_migrate(src_root, dest_root)
if __name__ == '__main__':
if __name__ == "__main__":
main()
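A minimal, hypothetical driver for the migration entry point above; the two root paths are placeholders, and both directories are assumed to already exist and to be laid out as the asserts in main() require.
from pathlib import Path

# Paths are placeholders; do_migrate() is the function defined in the diff above.
src_root = Path("~/invokeai-2.3").expanduser()   # contains invokeai.init or invokeai.yaml
dest_root = Path("~/invokeai-3.0").expanduser()  # an initialized 3.0 root
do_migrate(src_root, dest_root)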

View File

@ -4,7 +4,7 @@ Utility (backend) functions used by model_install.py
import os
import shutil
import warnings
from dataclasses import dataclass,field
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set
@ -28,7 +28,7 @@ warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger(name='InvokeAI')
logger = InvokeAILogger.getLogger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
@ -45,51 +45,63 @@ Config_preamble = """
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: 'v1-inference.yaml',
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
@dataclass
class ModelInstallList:
'''Class for listing models to be installed/removed'''
"""Class for listing models to be installed/removed"""
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class InstallSelections():
install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list)
@dataclass
class ModelLoadInfo():
class InstallSelections:
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class ModelLoadInfo:
name: str
model_type: ModelType
base_type: BaseModelType
path: Path = None
repo_id: str = None
description: str = ''
description: str = ""
installed: bool = False
recommended: bool = False
default: bool = False
class ModelInstall(object):
def __init__(self,
config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
model_manager: ModelManager = None,
access_token:str = None):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
model_manager: ModelManager = None,
access_token: str = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
@ -97,65 +109,66 @@ class ModelInstall(object):
self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self)->Dict[str,ModelLoadInfo]:
'''
def all_models(self) -> Dict[str, ModelLoadInfo]:
"""
Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
'''
"""
model_dict = dict()
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
name,base,model_type = ModelManager.parse_key(key)
value['name'] = name
value['base_type'] = base
value['model_type'] = model_type
name, base, model_type = ModelManager.parse_key(key)
value["name"] = name
value["base_type"] = base
value["model_type"] = model_type
model_dict[key] = ModelLoadInfo(**value)
# supplement with entries in models.yaml
installed_models = self.mgr.list_models()
for md in installed_models:
base = md['base_model']
model_type = md['model_type']
name = md['model_name']
base = md["base_model"]
model_type = md["model_type"]
name = md["model_name"]
key = ModelManager.create_key(name, base, model_type)
if key in model_dict:
model_dict[key].installed = True
else:
model_dict[key] = ModelLoadInfo(
name = name,
base_type = base,
model_type = model_type,
path = value.get('path'),
installed = True,
name=name,
base_type=base,
model_type=model_type,
path=value.get("path"),
installed=True,
)
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print(f'Installed models of type `{model_type}`:')
print(f"Installed models of type `{model_type}`:")
for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]:
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
name,base,model_type = ModelManager.parse_key(key)
if model_type==ModelType.Main:
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
return models
def recommended_models(self)->Set[str]:
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)])
def default_model(self) -> str:
starters = self.starter_models()
return set([x for x in starters if self.datasets[x].get('recommended',False)])
def default_model(self)->str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get('default',False)]
defaults = [x for x in starters if self.datasets[x].get("default", False)]
return defaults[0]
def install(self, selections: InstallSelections):
@ -164,54 +177,57 @@ class ModelInstall(object):
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
# remove requested models
for key in selections.remove_models:
name,base,mtype = self.mgr.parse_key(key)
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
name, base, mtype = self.mgr.parse_key(key)
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
try:
self.mgr.del_model(name,base,mtype)
self.mgr.del_model(name, base, mtype)
except FileNotFoundError as e:
logger.warning(e)
job += 1
# add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
logger.info(f"Installing {path} [{job}/{jobs}]")
try:
self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1
dlogging.set_verbosity(verbosity)
self.mgr.commit()
def heuristic_import(self,
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None,
)->Dict[str, AddModelResult]:
'''
def heuristic_import(
self,
model_path_id_or_url: Union[str, Path],
models_installed: Set[Path] = None,
) -> Dict[str, AddModelResult]:
"""
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
"""
if not models_installed:
models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path):self._install_path(path)})
models_installed.update({str(path): self._install_path(path)})
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
elif path.is_dir() and any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
@ -220,7 +236,7 @@ class ModelInstall(object):
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(model_path_id_or_url).split('/')) == 2:
elif len(str(model_path_id_or_url).split("/")) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
@ -228,42 +244,43 @@ class ModelInstall(object):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
return models_installed
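A hedged sketch of driving heuristic_import() directly; the repo ID below is illustrative and assumes network access plus an already-configured InvokeAI root.
from invokeai.app.services.config import InvokeAIAppConfig

installer = ModelInstall(config=InvokeAIAppConfig.get_config())
results = installer.heuristic_import("stabilityai/sd-vae-ft-mse")  # local path, repo ID, or URL
for key, result in results.items():
    print(key, result)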
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
if not info:
logger.warning(f'Unable to parse format of {path}')
logger.warning(f"Unable to parse format of {path}")
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
return self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
attributes = self._make_attributes(path, info)
return self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
model_attributes=attributes,
)
def _install_url(self, url: str)->AddModelResult:
def _install_url(self, url: str) -> AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
location = download_with_resume(url, Path(staging))
if not location:
logger.error(f'Unable to download {url}. Skipping.')
logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
models_path = shutil.move(location,dest)
models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->AddModelResult:
def _install_repo(self, repo_id: str) -> AddModelResult:
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
@ -271,42 +288,49 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if 'model_index.json' in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
if "model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
else:
for suffix in ['safetensors','bin']:
if f'pytorch_lora_weights.{suffix}' in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
for suffix in ["safetensors", "bin"]:
if f"pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
break
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
elif (
self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
elif f'diffusion_pytorch_model.{suffix}' in files:
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
elif f"diffusion_pytorch_model.{suffix}" in files:
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break
if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f'Could not probe {location}. Skipping install.')
logger.warning(f"Could not probe {location}. Skipping install.")
return {}
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
dest = (
self.config.models_path
/ info.base_type.value
/ info.model_type.value
/ self._get_model_name(repo_id, location)
)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location,dest)
shutil.copytree(location, dest)
return self._install_path(dest, info)
def _get_model_name(self,path_name: str, location: Path)->str:
'''
def _get_model_name(self, path_name: str, location: Path) -> str:
"""
Calculate a name for the model - primitive implementation.
'''
"""
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
@ -315,92 +339,103 @@ class ModelInstall(object):
else:
return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
model_name = path.name if path.is_dir() else path.stem
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key].get('description') or description
description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path)
attributes = dict(
path = str(rel_path),
description = str(description),
model_format = info.format,
)
path=str(rel_path),
description=str(description),
model_format=info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,))
if info.format=="checkpoint":
attributes.update(
dict(
variant=info.variant_type,
)
)
if info.format == "checkpoint":
try:
possible_conf = path.with_suffix('.yaml')
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type == BaseModelType.StableDiffusion2:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
)
else:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
legacy_conf = Path(
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
)
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
attributes.update(
dict(
config = str(legacy_conf)
)
)
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
return attributes
def relative_to_root(self, path: Path)->Path:
def relative_to_root(self, path: Path) -> Path:
root = self.config.root_path
if path.is_relative_to(root):
return path.relative_to(root)
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
'''
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
"""
This retrieves a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area.
'''
_,name = repo_id.split("/")
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
"""
_, name = repo_id.split("/")
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
model = None
for revision in revisions:
try:
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
if model:
break
if not model:
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
_,name = repo_id.split("/")
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
for filename in files:
p = hf_download_with_resume(repo_id,
model_dir=location,
model_name=filename,
access_token = self.access_token
)
p = hf_download_with_resume(
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
)
if p:
paths.append(p)
else:
logger.warning(f'Could not download {filename} from {repo_id}.')
return location if len(paths)>0 else None
logger.warning(f"Could not download {filename} from {repo_id}.")
return location if len(paths) > 0 else None
@classmethod
def _reverse_paths(cls,datasets)->dict:
'''
def _reverse_paths(cls, datasets) -> dict:
"""
Reverse mapping from repo_id/path to destination name.
'''
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
"""
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
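# Example of the reversal (dataset keys and values are illustrative):
#   {"sd-1/main/foo": {"repo_id": "org/foo"}, "sd-1/vae/bar": {"path": "models/bar"}}
#   becomes
#   {"org/foo": "sd-1/main/foo", "models/bar": "sd-1/vae/bar"}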
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
@ -411,13 +446,12 @@ def yes_or_no(prompt: str, default_yes=True):
else:
return response[0] in ("y", "Y")
# ---------------------------------------------
def hf_download_from_pretrained(
model_class: object, model_name: str, destination: Path, **kwargs
):
logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger = InvokeAILogger.getLogger("InvokeAI")
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
model = model_class.from_pretrained(
model_name,
resume_download=True,
@ -426,13 +460,14 @@ def hf_download_from_pretrained(
model.save_pretrained(destination, safe_serialization=True)
return destination
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
@ -451,9 +486,7 @@ def hf_download_with_resume(
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
@ -482,5 +515,3 @@ def hf_download_with_resume(
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
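# Usage sketch (repo id and filename are illustrative):
#   dest = hf_download_with_resume(
#       repo_id="stabilityai/sd-vae-ft-mse",
#       model_dir="models/core/convert/sd-vae-ft-mse",
#       model_name="diffusion_pytorch_model.safetensors",
#   )
#   # dest is the downloaded file's Path, or None if the download failed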

View File

@ -3,6 +3,12 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException
from .models import (
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
ModelNotFoundException,
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod

File diff suppressed because it is too large

View File

@ -11,14 +11,15 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
class LoRALayerBase:
#rank: Optional[int]
#alpha: Optional[float]
#bias: Optional[torch.Tensor]
#layer_key: str
#@property
#def scale(self):
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
@ -31,11 +32,7 @@ class LoRALayerBase:
else:
self.alpha = None
if (
"bias_indices" in values
and "bias_values" in values
and "bias_size" in values
):
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
@ -45,13 +42,13 @@ class LoRALayerBase:
else:
self.bias = None
self.rank = None # set in layer implementation
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
@ -71,12 +68,16 @@ class LoRALayerBase:
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
) * multiplier * scale
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
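# In effect the LoRA contribution returned here is op(x, ΔW) * multiplier * (alpha / rank),
# where ΔW = get_weight() + bias and the scale falls back to 1.0 when alpha or rank is unset.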
def get_weight(self):
raise NotImplementedError()
@ -99,9 +100,9 @@ class LoRALayerBase:
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
#up: torch.Tensor
#mid: Optional[torch.Tensor]
#down: torch.Tensor
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
@ -151,12 +152,12 @@ class LoRALayer(LoRALayerBase):
class LoHALayer(LoRALayerBase):
#w1_a: torch.Tensor
#w1_b: torch.Tensor
#w2_a: torch.Tensor
#w2_b: torch.Tensor
#t1: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
@ -187,12 +188,8 @@ class LoHALayer(LoRALayerBase):
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
)
rebuild2 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
)
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
@ -223,20 +220,20 @@ class LoHALayer(LoRALayerBase):
class LoKRLayer(LoRALayerBase):
#w1: Optional[torch.Tensor] = None
#w1_a: Optional[torch.Tensor] = None
#w1_b: Optional[torch.Tensor] = None
#w2: Optional[torch.Tensor] = None
#w2_a: Optional[torch.Tensor] = None
#w2_b: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
@ -266,7 +263,7 @@ class LoKRLayer(LoRALayerBase):
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
@ -278,7 +275,7 @@ class LoKRLayer(LoRALayerBase):
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@ -317,7 +314,7 @@ class LoKRLayer(LoRALayerBase):
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: #(torch.nn.Module):
class LoRAModel: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
@ -345,7 +342,7 @@ class LoRAModel: #(torch.nn.Module):
@property
def dtype(self):
return self._dtype
return self._dtype
def to(
self,
@ -380,7 +377,7 @@ class LoRAModel: #(torch.nn.Module):
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
name=file_path.stem, # TODO:
layers=dict(),
)
@ -392,7 +389,6 @@ class LoRAModel: #(torch.nn.Module):
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
@ -407,9 +403,7 @@ class LoRAModel: #(torch.nn.Module):
else:
# TODO: diff/ia3/... format
print(
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
)
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
@ -443,9 +437,10 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# unmodified unet
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@ -455,10 +450,10 @@ class ModelPatcher:
module = model
module_key = ""
key_parts = lora_key[len(prefix):].split('_')
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
@ -477,7 +472,6 @@ class ModelPatcher:
applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str,
):
def lora_forward(module, input_h, output):
if len(applied_loras) == 0:
return output
@ -491,7 +485,6 @@ class ModelPatcher:
return lora_forward
@classmethod
@contextmanager
def apply_lora_unet(
@ -502,7 +495,6 @@ class ModelPatcher:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
@ -513,7 +505,6 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_lora(
@ -526,7 +517,7 @@ class ModelPatcher:
try:
with torch.no_grad():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
@ -536,7 +527,7 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
#with torch.autocast(device_type="cpu"):
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
@ -547,14 +538,13 @@ class ModelPatcher:
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
yield # wait for context manager exit
yield # wait for context manager exit
finally:
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
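# Usage sketch (illustrative; `unet` is a loaded UNet2DConditionModel and `lora` a LoRAModel):
#   with ModelPatcher.apply_lora_unet(unet, [(lora, 0.75)]):
#       ...  # run inference against the patched weights
#   # on exit the saved original weights are copied back, as in the finally block above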
@classmethod
@contextmanager
def apply_ti(
@ -602,7 +592,9 @@ class ModelPatcher:
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
model_embeddings.weight.data[token_id] = embedding.to(
device=text_encoder.device, dtype=text_encoder.dtype
)
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
@ -614,7 +606,6 @@ class ModelPatcher:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
@classmethod
@contextmanager
def apply_clip_skip(
@ -633,9 +624,10 @@ class ModelPatcher:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod
def from_checkpoint(
@ -647,8 +639,8 @@ class TextualInversionModel:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@ -659,7 +651,9 @@ class TextualInversionModel:
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -688,10 +682,7 @@ class TextualInversionManager(BaseTextualInversionManager):
self.pad_tokens = dict()
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(
self, token_ids: list[int]
) -> list[int]:
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
@ -707,4 +698,3 @@ class TextualInversionManager(BaseTextualInversionManager):
new_token_ids.extend(self.pad_tokens[token_id])
return new_token_ids

View File

@ -37,19 +37,22 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE= 2.75
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
@ -79,22 +82,22 @@ class _CacheRecord:
return self.model.device != self.cache.storage_device
else:
return False
class ModelCache(object):
def __init__(
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
logger: types.ModuleType = logger,
):
'''
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@ -102,16 +105,16 @@ class ModelCache(object):
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
"""
self.model_infos: Dict[str, ModelBase] = dict()
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype=precision
self.max_cache_size: float=max_cache_size
self.max_vram_cache_size: float=max_vram_cache_size
self.execution_device: torch.device=execution_device
self.storage_device: torch.device=storage_device
self.sha_chunksize=sha_chunksize
self.precision: torch.dtype = precision
self.max_cache_size: float = max_cache_size
self.max_vram_cache_size: float = max_vram_cache_size
self.execution_device: torch.device = execution_device
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._cached_models = dict()
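# Construction sketch (argument values are illustrative; the defaults above apply when omitted):
#   cache = ModelCache(
#       max_cache_size=6.0,  # GB of RAM for inactive models
#       max_vram_cache_size=2.75,  # GB of VRAM retained by loaded models
#       execution_device=torch.device("cuda"),
#       precision=torch.float16,
#   )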
@ -124,7 +127,6 @@ class ModelCache(object):
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
@ -185,7 +187,7 @@ class ModelCache(object):
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
self.logger.info(f"Loading model {model_path}, type {base_model}:{model_type}:{submodel}")
# this will remove older cached models until
# there is sufficient room to load the requested model
@ -195,7 +197,7 @@ class ModelCache(object):
gc.collect()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
@ -208,13 +210,13 @@ class ModelCache(object):
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed):
'''
"""
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock
:param gpu_load: True if load into gpu
:param size_needed: Size of the model to load
'''
"""
self.gpu_load = gpu_load
self.cache = cache
self.key = key
@ -223,7 +225,7 @@ class ModelCache(object):
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, 'to'):
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this
@ -233,22 +235,21 @@ class ModelCache(object):
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models(self.size_needed)
self.cache._offload_unlocked_models(self.size_needed)
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
except:
self.cache_entry.unlock()
raise
# TODO: not fully understand
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
@ -258,7 +259,7 @@ class ModelCache(object):
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, 'to'):
if not hasattr(self.model, "to"):
return
self.cache_entry.unlock()
@ -276,11 +277,11 @@ class ModelCache(object):
self,
model_path: Union[str, Path],
) -> str:
'''
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
'''
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
@ -289,7 +290,7 @@ class ModelCache(object):
return current_cache_size / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'
return self.execution_device.type == "cuda"
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@ -305,18 +306,21 @@ class ModelCache(object):
if model_info.locked:
locked_models += 1
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
#multiplier = 2 if self.precision==torch.float32 else 1
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
if current_size + bytes_needed > maximum_size:
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
@ -338,7 +342,7 @@ class ModelCache(object):
with suppress(RuntimeError):
referrer.clear()
cleared = True
#break
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
@ -347,13 +351,17 @@ class ModelCache(object):
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}"
)
# 2 refs:
# 1 from cache_entry
# 1 from getrefcount function
if not cache_entry.locked and refs <= 2:
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
del self._cache_stack[pos]
del self._cached_models[model_key]
@ -367,38 +375,36 @@ class ModelCache(object):
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
def _offload_unlocked_models(self, size_needed: int=0):
def _offload_unlocked_models(self, size_needed: int = 0):
reserved = self.max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size):
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
vram_in_use += mem.vram_used # note vram_used is negative
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
torch.cuda.empty_cache()
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f'computing hash of model {path.name}')
for file in list(path.rglob("*.ckpt")) \
+ list(path.rglob("*.safetensors")) \
+ list(path.rglob("*.pth")):
self.logger.debug(f"computing hash of model {path.name}")
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
@ -407,11 +413,12 @@ class ModelCache(object):
f.write(hash)
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self

View File

@ -249,20 +249,26 @@ from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
BaseModelType,
ModelType,
SubModelType,
ModelError,
SchedulerPredictionType,
MODEL_CLASSES,
ModelConfigBase,
ModelNotFoundException, InvalidModelException,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
CONFIG_FILE_VERSION = "3.0.0"
@dataclass
class ModelInfo():
class ModelInfo:
context: ModelLocker
name: str
base_model: BaseModelType
@ -275,20 +281,24 @@ class ModelInfo():
def __enter__(self):
return self.context.__enter__()
def __exit__(self,*args, **kwargs):
def __exit__(self, *args, **kwargs):
self.context.__exit__(*args, **kwargs)
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
MAX_CACHE_SIZE = 6.0 # GB
class ConfigMeta(BaseModel):
version: str
class ModelManager(object):
"""
High-level interface to model management.
@ -315,12 +325,12 @@ class ModelManager(object):
if isinstance(config, (str, Path)):
self.config_path = Path(config)
if not self.config_path.exists():
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
logger.warning(f"The file {self.config_path} was not found. Initializing a new file")
self.initialize_model_config(self.config_path)
config = OmegaConf.load(self.config_path)
elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
raise ValueError("config argument must be an OmegaConf object, a Path or a string")
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
@ -330,11 +340,11 @@ class ModelManager(object):
self.logger = logger
self.cache = ModelCache(
max_cache_size=max_cache_size,
max_vram_cache_size = self.app_config.max_vram_cache_size,
execution_device = device_type,
precision = precision,
sequential_offload = sequential_offload,
logger = logger,
max_vram_cache_size=self.app_config.max_vram_cache_size,
execution_device=device_type,
precision=precision,
sequential_offload=sequential_offload,
logger=logger,
)
self._read_models(config)
@ -348,7 +358,7 @@ class ModelManager(object):
self.models = dict()
for model_key, model_config in config.items():
if model_key.startswith('_'):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
@ -399,7 +409,7 @@ class ModelManager(object):
@classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
base_model_str, model_type_str, model_name = model_key.split('/', 2)
base_model_str, model_type_str, model_name = model_key.split("/", 2)
try:
model_type = ModelType(model_type_str)
except:
@ -418,20 +428,16 @@ class ModelManager(object):
@classmethod
def initialize_model_config(cls, config_path: Path):
"""Create empty config file"""
with open(config_path,'w') as yaml_file:
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
)
)
with open(config_path, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
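# The freshly initialized file contains only the metadata block:
#   __metadata__:
#     version: 3.0.0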
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None
)->ModelInfo:
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""Given a model named identified in models.yaml, return
an ModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
@ -455,7 +461,7 @@ class ModelManager(object):
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f"Files for model \"{model_key}\" not found")
raise Exception(f'Files for model "{model_key}" not found')
else:
self.models.pop(model_key, None)
@ -477,7 +483,7 @@ class ModelManager(object):
model_path = model_class.convert_if_required(
base_model=base_model,
model_path=str(model_path), # TODO: refactor str/Path types logic
model_path=str(model_path), # TODO: refactor str/Path types logic
output_path=dst_convert_path,
config=model_config,
)
@ -494,17 +500,17 @@ class ModelManager(object):
self.cache_keys[model_key] = set()
self.cache_keys[model_key].add(model_context.key)
model_hash = "<NO_HASH>" # TODO:
model_hash = "<NO_HASH>" # TODO:
return ModelInfo(
context = model_context,
name = model_name,
base_model = base_model,
type = submodel_type or model_type,
hash = model_hash,
location = model_path, # TODO:
precision = self.cache.precision,
_cache = self.cache,
context=model_context,
name=model_name,
base_model=base_model,
type=submodel_type or model_type,
hash=model_hash,
location=model_path, # TODO:
precision=self.cache.precision,
_cache=self.cache,
)
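# Usage sketch (model name and submodel are illustrative):
#   model_info = manager.get_model(
#       "stable-diffusion-v1-5", BaseModelType.StableDiffusion1, ModelType.Main,
#       submodel_type=SubModelType.Vae,
#   )
#   with model_info as vae:
#       ...  # the context manager keeps the (sub)model loaded while in use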
def model_info(
@ -520,7 +526,7 @@ class ModelManager(object):
if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
else:
return None # TODO: None or empty dict on not found
return None # TODO: None or empty dict on not found
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
@ -530,16 +536,16 @@ class ModelManager(object):
return [(self.parse_key(x)) for x in self.models.keys()]
def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
models = self.list_models(base_model, model_type, model_name)
return models[0] if models else None
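# Usage sketch (illustrative; each returned dict includes at least the model's "path"):
#   for md in manager.list_models(base_model=BaseModelType.StableDiffusion1, model_type=ModelType.Main):
#       print(md["path"])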
def list_models(
@ -552,13 +558,17 @@ class ModelManager(object):
Return a list of models.
"""
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
model_keys = (
[self.create_key(model_name, base_model, model_type)]
if model_name
else sorted(self.models, key=str.casefold)
)
models = []
for model_key in model_keys:
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise ModelNotFoundException(f'Unknown model {model_name}')
self.logger.error(f"Unknown model {model_name}")
raise ModelNotFoundException(f"Unknown model {model_name}")
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@ -575,8 +585,8 @@ class ModelManager(object):
)
# expose paths as absolute to help web UI
if path := model_dict.get('path'):
model_dict['path'] = str(self.app_config.root_path / path)
if path := model_dict.get("path"):
model_dict["path"] = str(self.app_config.root_path / path)
models.append(model_dict)
return models
@ -645,15 +655,15 @@ class ModelManager(object):
model_info().
"""
# relativize paths as they go in - this makes it easier to move the root directory around
if path := model_attributes.get('path'):
if path := model_attributes.get("path"):
if Path(path).is_relative_to(self.app_config.root_path):
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models and not clobber:
if model_key in self.models and not clobber:
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
old_model = self.models.pop(model_key, None)
@ -677,24 +687,25 @@ class ModelManager(object):
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
base_model = base_model,
config = model_config,
name=model_name,
model_type=model_type,
base_model=base_model,
config=model_config,
)
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
'''
"""
Rename or rebase a model.
'''
"""
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
@ -713,7 +724,13 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
new_path = (
self.app_config.root_path
/ "models"
/ BaseModelType(new_base).value
/ ModelType(model_type).value
/ new_name
)
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
@ -729,18 +746,18 @@ class ModelManager(object):
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
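# Usage sketch (names are illustrative):
#   manager.rename_model(
#       "old-name", BaseModelType.StableDiffusion1, ModelType.Main, new_name="new-name"
#   )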
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None,
) -> AddModelResult:
'''
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
@ -749,7 +766,7 @@ class ModelManager(object):
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint.
'''
"""
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
@ -757,27 +774,32 @@ class ModelManager(object):
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodel type we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {}
model = self.get_model(
model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value
) / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
move(old_diffusers_path, new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
info["path"] = (
str(new_diffusers_path)
if dest_directory
else str(new_diffusers_path.relative_to(self.app_config.root_path))
)
info.pop("config")
result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
@ -801,15 +823,12 @@ class ModelManager(object):
found_models = []
for file in files:
location = str(file.resolve()).replace("\\", "/")
if (
"model.safetensors" not in location
and "diffusion_pytorch_model.safetensors" not in location
):
if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location:
found_models.append({"name": file.stem, "location": location})
return search_folder, found_models
def commit(self, conf_file: Path=None) -> None:
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
"""
@ -827,7 +846,7 @@ class ModelManager(object):
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path
assert config_file_path is not None,'no config file path to write to'
assert config_file_path is not None, "no config file path to write to"
config_file_path = self.app_config.root_path / config_file_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
try:
@ -844,7 +863,7 @@ class ModelManager(object):
Returns the preamble for the config file.
"""
return textwrap.dedent(
"""\
"""
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
@ -860,11 +879,10 @@ class ModelManager(object):
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
):
loaded_files = set()
new_models_found = False
self.logger.info(f'Scanning {self.app_config.models_path} for new models')
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
with Chdir(self.app_config.root_path):
for model_key, model_config in list(self.models.items()):
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -890,10 +908,10 @@ class ModelManager(object):
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
if not models_dir.exists():
continue # TODO: or create all folders?
continue # TODO: or create all folders?
for model_path in models_dir.iterdir():
if model_path not in loaded_files: # TODO: check
if model_path not in loaded_files: # TODO: check
model_name = model_path.name if model_path.is_dir() else model_path.stem
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
@ -903,7 +921,7 @@ class ModelManager(object):
if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
@ -919,11 +937,10 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self)->Dict[str, AddModelResult]:
'''
def autoimport(self) -> Dict[str, AddModelResult]:
"""
Scan the autoimport directory (if defined) and import new models, delete defunct models.
'''
"""
# avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
@ -942,7 +959,9 @@ class ModelManager(object):
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
self.logger.info(
f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models"
)
def models_found(self):
return self.new_models_found
@ -952,31 +971,37 @@ class ModelManager(object):
# LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({config.root_path / 'models/core/convert/sd-vae-ft-mse'})
self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"})
except:
pass
installer = ModelInstall(config = self.app_config,
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir,
] if x
}
installer = ModelInstall(
config=self.app_config,
model_manager=self,
prediction_type_helper=ask_user_for_prediction_type,
)
known_paths = {config.root_path / x["path"] for x in self.list_models()}
directories = {
config.root_path / x
for x in [
config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir,
]
if x
}
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
return scanner.models_found()
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
def heuristic_import(
self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@ -995,14 +1020,15 @@ class ModelManager(object):
May raise the following exceptions:
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
"""
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
)
for thing in items_to_import:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
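# Usage sketch (the repo id is only an example):
#   results = manager.heuristic_import({"runwayml/stable-diffusion-v1-5"})
#   for key, res in results.items():
#       print(key, res.name, res.base_model, res.model_type)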

View File

@ -17,23 +17,25 @@ import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
@ -58,24 +60,23 @@ class ModelMerger(object):
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save (
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
@ -94,39 +95,45 @@ class ModelMerger(object):
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}")
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
dump_path = (
Path(merge_dest_directory)
if merge_dest_directory
else config.models_path / base_model.value / ModelType.Main.value
)
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
path=str(dump_path),
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
return self.manager.add_model(
merged_model_name,
base_model=base_model,
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
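# Usage sketch (model names are illustrative):
#   merger = ModelMerger(manager)
#   merger.merge_diffusion_models_and_save(
#       model_names=["model-a", "model-b"],
#       base_model=BaseModelType.StableDiffusion1,
#       merged_model_name="model-ab",
#       alpha=0.5,
#       interp=MergeInterpolationMethod.WeightedSum,
#   )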

View File

@ -10,12 +10,16 @@ from typing import Callable, Literal, Union, Dict, Optional
from picklescan.scanner import scan_file_path
from .models import (
BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
InvalidModelException
BaseModelType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
InvalidModelException,
)
from .models.base import read_checkpoint_meta
@dataclass
class ModelProbeInfo(object):
model_type: ModelType
@ -23,70 +27,74 @@ class ModelProbeInfo(object):
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['diffusers','checkpoint', 'lycoris']
format: Literal["diffusers", "checkpoint", "lycoris"]
image_size: int
class ProbeBase(object):
'''forward declaration'''
"""forward declaration"""
pass
class ModelProbe(object):
PROBES = {
'diffusers': { },
'checkpoint': { },
"diffusers": {},
"checkpoint": {},
}
CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Main,
'StableDiffusionInpaintPipeline' : ModelType.Main,
'StableDiffusionXLPipeline' : ModelType.Main,
'StableDiffusionXLImg2ImgPipeline' : ModelType.Main,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
}
@classmethod
def register_probe(cls,
format: Literal['diffusers','checkpoint'],
model_type: ModelType,
probe_class: ProbeBase):
def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->ModelProbeInfo:
if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
def heuristic_probe(
cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> ModelProbeInfo:
if isinstance(model, Path):
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo:
'''
def probe(
cls,
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the SchedulerPredictionType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
"""
if model_path:
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
else:
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format_type == 'diffusers' \
else cls.get_model_type_from_checkpoint(model_path, model)
model_type = (
cls.get_model_type_from_folder(model_path, model)
if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model)
)
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
@ -96,17 +104,23 @@ class ModelProbe(object):
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
model_info = ModelProbeInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
prediction_type = prediction_type,
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format,
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction ) else \
512
model_type=model_type,
base_type=base_type,
variant_type=variant_type,
prediction_type=prediction_type,
upcast_attention=(
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512,
)
except Exception:
raise
@ -115,7 +129,7 @@ class ModelProbe(object):
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
return None
if model_path.name == "learned_embeds.bin":
@ -142,32 +156,32 @@ class ModelProbe(object):
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
'''
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
"""
Get the model type of a hugging-face style folder.
'''
"""
class_name = None
if model:
class_name = model.__class__.__name__
else:
if (folder_path / 'learned_embeds.bin').exists():
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists():
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
i = folder_path / 'model_index.json'
c = folder_path / 'config.json'
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path,'r') as file:
with open(config_path, "r") as file:
conf = json.load(file)
class_name = conf['_class_name']
class_name = conf["_class_name"]
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
@ -176,7 +190,7 @@ class ModelProbe(object):
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
@ -186,55 +200,53 @@ class ModelProbe(object):
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
###################################################3
# Checkpoint probing
###################################################3
class ProbeBase(object):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
pass
def get_variant_type(self)->ModelVariantType:
pass
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
def get_variant_type(self) -> ModelVariantType:
pass
def get_format(self)->str:
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
pass
def get_format(self) -> str:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
checkpoint: dict,
helper: Callable[[Path],SchedulerPredictionType] = None
)->BaseModelType:
def __init__(
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
) -> BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
pass
def get_format(self)->str:
return 'checkpoint'
def get_format(self) -> str:
return "checkpoint"
def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
@ -242,51 +254,60 @@ class CheckpointProbeBase(ProbeBase):
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
raise InvalidModelException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
)
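The variant detection above keys entirely off the input-channel count of the UNet's first convolution. A minimal sketch of that heuristic, returning plain strings instead of ModelVariantType; variant_from_in_channels is an illustrative helper and only the channel counts visible above are handled:

def variant_from_in_channels(state_dict: dict) -> str:
    # the first UNet conv weight has shape [out_ch, in_ch, k, k]; inpainting
    # checkpoints concatenate mask and masked-image latents, giving 9 input channels
    in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
    if in_channels == 9:
        return "inpaint"
    if in_channels == 4:
        return "normal"
    raise ValueError(f"Cannot determine variant type (in_channels={in_channels})")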
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
type = self.get_base_type()
if type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper \
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
if (
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
): # if a .yaml config file exists, then this step is not needed
return self.helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return 'lycoris'
def get_base_type(self)->BaseModelType:
class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self) -> str:
return "lycoris"
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
@ -304,16 +325,17 @@ class LoRACheckpointProbe(CheckpointProbeBase):
else:
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
def get_format(self) -> str:
return None
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
@ -323,12 +345,14 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
@ -339,56 +363,54 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return self.helper(self.checkpoint_path)
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self,
folder_path: Path,
model: ModelMixin = None,
helper: Callable=None # not used
):
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->ModelVariantType:
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self)->str:
return 'diffusers'
def get_format(self) -> str:
return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
if self.model:
unet_conf = self.model.unet.config
else:
with open(self.folder_path / 'unet' / 'config.json','r') as file:
with open(self.folder_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf['cross_attention_dim'] == 1280:
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf['cross_attention_dim'] == 2048:
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf['prediction_type'] == "v_prediction":
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf['prediction_type'] == 'epsilon':
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self)->ModelVariantType:
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
@ -396,11 +418,11 @@ class PipelineFolderProbe(FolderProbeBase):
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
with open(config_file,'r') as file:
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf['in_channels']
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
@ -411,53 +433,67 @@ class PipelineFolderProbe(FolderProbeBase):
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
return (
BaseModelType.StableDiffusionXL
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
else BaseModelType.StableDiffusion1
)
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self)->str:
def get_format(self) -> str:
return None
def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin'
def get_base_type(self) -> BaseModelType:
path = self.folder_path / "learned_embeds.bin"
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None,checkpoint=checkpoint).get_base_type()
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
return BaseModelType.StableDiffusion1 \
if config['cross_attention_dim']==768 \
else BaseModelType.StableDiffusion2
return (
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2
)
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ['safetensors','bin']:
base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
for suffix in ["safetensors", "bin"]:
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelException('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
raise InvalidModelException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file, None).get_base_type()
############## register probe classes ######
ModelProbe.register_probe('diffusers', ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('checkpoint', ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)

View File

@ -10,8 +10,9 @@ from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
@ -56,18 +57,23 @@ class ModelSearch(ABC):
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
if any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
):
try:
self.on_model_found(path)
self._models_found += 1
@ -79,18 +85,19 @@ class ModelSearch(ABC):
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
def on_model_found(self, model: Path):
self.models_found.add(model)
def on_search_completed(self):
@ -99,5 +106,3 @@ class FindModels(ModelSearch):
def list_models(self) -> List[Path]:
self.search()
return list(self.models_found)
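A usage sketch of the concrete searcher above; the directories are placeholders, and list_models() runs the recursive walk and returns every checkpoint file and model folder it finds:

from pathlib import Path

search = FindModels([Path("~/invokeai/models").expanduser(), Path("/mnt/extra-models")])
for model_path in search.list_models():
    print(model_path)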

View File

@ -3,15 +3,24 @@ from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import (
BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase,
ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings,
ModelNotFoundException, InvalidModelException, DuplicateModelException
)
BaseModelType,
ModelType,
SubModelType,
ModelBase,
ModelConfigBase,
ModelVariantType,
SchedulerPredictionType,
ModelError,
SilenceWarnings,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
)
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
@ -45,18 +54,19 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
@ -72,27 +82,31 @@ for base_model, models in MODEL_CLASSES.items():
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
model_type=Literal[model_type.value],
api_wrapper = type(
openapi_cfg_name,
(cfg, OpenAPIModelInfoBase),
dict(
__annotations__=dict(
model_type=Literal[model_type.value],
),
),
))
)
#globals()[openapi_cfg_name] = api_wrapper
# globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
if hasattr(inspect,'get_annotations'):
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
@ -109,7 +123,9 @@ def get_model_config_enums():
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
elif get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
enums.append(type(field.__args__[0]))
elif field is None:
@ -119,4 +135,3 @@ def get_model_config_enums():
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums
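MODEL_CLASSES is a two-level mapping from base model to model type to implementation class, so resolving the class that handles a given combination is a plain dictionary lookup. A short sketch, assuming the SD 1.x entry maps ModelType.Main to StableDiffusion1Model as the imports suggest:

# look up the implementation class that handles SD 1.x main (pipeline) models
model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.Main]
print(model_class)  # expected: StableDiffusion1Model
# every config class collected above also gets an OpenAPI wrapper
print(len(OPENAPI_MODEL_CONFIGS), "OpenAPI model config variants")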

View File

@ -15,29 +15,35 @@ from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class DuplicateModelException(Exception):
pass
class InvalidModelException(Exception):
pass
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
#Kandinsky2_1 = "kandinsky-2.1"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
@ -47,23 +53,27 @@ class SubModelType(str, Enum):
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
# MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
@ -71,13 +81,17 @@ class ModelConfigBase(BaseModel):
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
@ -86,12 +100,13 @@ class classproperty(Generic[T_co]):
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
raise AttributeError("cannot set attribute")
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
# model_path: str
# base_model: BaseModelType
# model_type: ModelType
def __init__(
self,
@ -110,7 +125,7 @@ class ModelBase(metaclass=ABCMeta):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
@ -119,7 +134,6 @@ class ModelBase(metaclass=ABCMeta):
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@ -128,7 +142,7 @@ class ModelBase(metaclass=ABCMeta):
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = dict()
for name in dir(cls):
if name.startswith("__"):
@ -138,7 +152,7 @@ class ModelBase(metaclass=ABCMeta):
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
if hasattr(inspect,'get_annotations'):
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(value)
else:
fields = value.__annotations__
@ -151,7 +165,9 @@ class ModelBase(metaclass=ABCMeta):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
elif typing.get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
for model_format in field.__args__:
configs[model_format.value] = value
@ -203,8 +219,8 @@ class ModelBase(metaclass=ABCMeta):
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
# child_types: Dict[str, Type]
# child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
@ -214,7 +230,7 @@ class DiffusersModel(ModelBase):
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
@ -228,14 +244,12 @@ class DiffusersModel(ModelBase):
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
@ -245,7 +259,7 @@ class DiffusersModel(ModelBase):
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
@ -265,8 +279,8 @@ class DiffusersModel(ModelBase):
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
# print("====ERR LOAD====")
# print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@ -275,15 +289,10 @@ class DiffusersModel(ModelBase):
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
@ -325,12 +334,12 @@ def calc_model_size_by_fs(
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
@ -343,9 +352,9 @@ def calc_model_size_by_fs(
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
@ -364,12 +373,12 @@ def _calc_pipeline_by_data(pipeline) -> int:
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
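The parameter-and-buffer accounting above works for any torch.nn.Module; a quick self-contained check of the formula (module_bytes is an illustrative restatement, not a function in this module):

import torch

def module_bytes(model: torch.nn.Module) -> int:
    # parameters and buffers are the only persistent tensors a module owns
    params = sum(p.nelement() * p.element_size() for p in model.parameters())
    bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
    return params + bufs

linear = torch.nn.Linear(1024, 1024)  # 1024*1024 weights + 1024 biases, fp32
print(module_bytes(linear))  # 4 * (1024 * 1024 + 1024) = 4_198_400 bytes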
@ -377,11 +386,15 @@ def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
@ -400,6 +413,7 @@ def _fast_safetensors_reader(path: str):
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
@ -411,25 +425,27 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
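Both readers rely on the safetensors on-disk layout: an 8-byte little-endian length prefix followed by a JSON header mapping tensor names to dtype, shape, and data offsets. A minimal sketch that extracts only the tensor shapes without loading any weights (read_safetensors_shapes is an illustrative helper):

import json

def read_safetensors_shapes(path: str) -> dict:
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), "little")
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)
    # each entry looks like {"dtype": "F16", "shape": [320, 4, 3, 3], "data_offsets": [start, end]}
    return {name: entry["shape"] for name, entry in header.items()}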
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
warnings.simplefilter("default")

View File

@ -1,7 +1,8 @@
import os
import torch
from enum import Enum
from typing import Optional
from pathlib import Path
from typing import Optional, Literal
from .base import (
ModelBase,
ModelConfigBase,
@ -15,17 +16,24 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
class ControlNetModel(ModelBase):
# model_class: Type
# model_size: int
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -33,7 +41,7 @@ class ControlNetModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
@ -61,7 +69,7 @@ class ControlNetModel(ModelBase):
raise Exception("There is no child models in controlnet model")
model = None
for variant in ['fp16',None]:
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
@ -73,7 +81,7 @@ class ControlNetModel(ModelBase):
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@ -102,10 +110,50 @@ class ControlNetModel(ModelBase):
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path=model_path,
model_config=config.config,
output_path=output_path,
base_model=base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return the Path to the converted file. If the
converted model already exists on disk, just return its Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file=app_config.root_path / model_config,
image_size=512,
scan_needed=True,
from_safetensors=weights.suffix == ".safetensors",
)
return output_path

View File

@ -12,18 +12,21 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
# model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora

View File

@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -14,12 +15,13 @@ from .base import (
)
from omegaconf import OmegaConf
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
@ -48,11 +50,11 @@ class StableDiffusionXLModel(DiffusersModel):
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
@ -60,7 +62,7 @@ class StableDiffusionXLModel(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@ -80,11 +82,10 @@ class StableDiffusionXLModel(DiffusersModel):
if ckpt_config_path is None:
# TO DO: implement picking
pass
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@ -108,7 +109,17 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File

@ -14,16 +14,20 @@ from .base import (
read_checkpoint_meta,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
@ -34,7 +38,7 @@ class StableDiffusion1Model(DiffusersModel):
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main
@ -55,7 +59,7 @@ class StableDiffusion1Model(DiffusersModel):
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion1ModelFormat.Diffusers:
@ -63,7 +67,7 @@ class StableDiffusion1Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
@ -84,7 +88,6 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@ -121,16 +124,17 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
)
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
@ -163,7 +167,7 @@ class StableDiffusion2Model(DiffusersModel):
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion2ModelFormat.Diffusers:
@ -171,7 +175,7 @@ class StableDiffusion2Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@ -194,7 +198,6 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@ -235,42 +238,19 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
# TODO: rework
# Note that convert_ckpt_to_diffusers does not currently support conversion of SDXL models
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
model_config: Union[
StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool = False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@ -289,14 +269,61 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
from ...util.devices import choose_torch_device, torch_dtype
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
logger.info(f"Converting {weights} to diffusers format")
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
model_type=model_base_to_model_type[version],
model_version=version,
model_variant=model_config.variant,
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
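A usage sketch of the lookup above: for an SD 1.x inpainting checkpoint it resolves to the bundled v1-inpainting-inference.yaml, returned relative to the InvokeAI root when possible (the printed path assumes a default install layout):

config = _select_ckpt_config(BaseModelType.StableDiffusion1, ModelVariantType.Inpaint)
print(config)  # e.g. "configs/stable-diffusion/v1-inpainting-inference.yaml"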

View File

@ -11,11 +11,13 @@ from .base import (
ModelNotFoundException,
InvalidModelException,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
# model_size: int
class Config(ModelConfigBase):
model_format: None
@ -65,7 +67,7 @@ class TextualInversionModel(ModelBase):
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
return None # diffusers-ti
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):

View File

@ -22,13 +22,15 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
# vae_class: Type
# model_size: int
class Config(ModelConfigBase):
model_format: VaeModelFormat
@ -39,7 +41,7 @@ class VaeModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
@ -95,7 +97,7 @@ class VaeModel(ModelBase):
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
@ -108,6 +110,7 @@ class VaeModel(ModelBase):
else:
return model_path
# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
@ -138,13 +141,14 @@ def _convert_vae_ckpt_and_cache(
2.1 - 768
"""
image_size = 512
# return cached version if it exists
if output_path.exists():
return output_path
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config
# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
@ -152,7 +156,8 @@ def _convert_vae_ckpt_and_cache(
# this avoids circular import error
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_path.suffix == '.safetensors':
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
@ -161,15 +166,12 @@ def _convert_vae_ckpt_and_cache(
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(app_config.root_path/config_file)
config = OmegaConf.load(app_config.root_path / config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size,
)
vae_model.save_pretrained(
output_path,
safe_serialization=is_safetensors_available()
checkpoint=checkpoint,
vae_config=config,
image_size=image_size,
)
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
return output_path

View File

@ -1,77 +0,0 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
try:
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
except Exception:
logger.error(
"An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
self.safety_checker.to(self.device)
features = self.safety_feature_extractor([image], return_tensors="pt")
features.to(self.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
logger.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
if caution := self.caution_img:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry

View File

@ -47,6 +47,7 @@ from .diffusion import (
)
from .offloading import FullyLoadedModelGroup, ModelGroup
@dataclass
class PipelineIntermediateState:
run_id: str
@ -72,7 +73,11 @@ class AddsMaskLatents:
initial_image_latents: torch.Tensor
def __call__(
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
@ -80,12 +85,8 @@ class AddsMaskLatents:
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(
self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
)
image_latents = einops.repeat(
self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
@ -103,9 +104,7 @@ class AddsMaskGuidance:
noise: torch.Tensor
_debug: Optional[Callable] = None
def __call__(
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
) -> BaseOutput:
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
@ -116,11 +115,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
k: (
self.apply_mask(v, self._t_for_field(k, t))
if are_like_tensors(prev_sample, v)
else v
)
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
for k, v in step_output.items()
}
)
@ -132,9 +127,7 @@ class AddsMaskGuidance:
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(
self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
# some schedulers expect t to be one-dimensional.
# TODO: file diffusers bug about inconsistency?
@ -144,12 +137,8 @@ class AddsMaskGuidance:
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(
mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
)
masked_input = torch.lerp(
mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)
)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self._debug:
self._debug(masked_input, f"t={t} lerped")
return masked_input
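The core of apply_mask is a per-element lerp between the re-noised init-image latents and the freshly denoised latents, weighted by the mask: where the mask is 1 the denoised latents win, where it is 0 the init-image latents are preserved. A toy sketch of that blend with made-up tensors:

import torch

latents = torch.full((1, 4, 2, 2), 2.0)             # stand-in for denoised latents
init_latents = torch.zeros((1, 4, 2, 2))            # stand-in for re-noised init-image latents
mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])   # 1 = repaint, 0 = preserve

blended = torch.lerp(init_latents, latents, mask)   # broadcasts the mask over channels
print(blended[0, 0])  # tensor([[2., 0.], [0., 2.]])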
@ -159,9 +148,7 @@ def trim_to_multiple_of(*args, multiple_of=8):
return tuple((x - x % multiple_of) for x in args)
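trim_to_multiple_of simply rounds each dimension down to the nearest multiple (8 by default, matching the latent-space downsampling factor), for example:

assert trim_to_multiple_of(517, 389) == (512, 384)
assert trim_to_multiple_of(512, 512) == (512, 512)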
def image_resized_to_grid_as_tensor(
image: PIL.Image.Image, normalize: bool = True, multiple_of=8
) -> torch.FloatTensor:
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
"""
:param image: input image
@ -211,6 +198,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
@ -341,9 +329,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# FIXME: can't currently register control module
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels)
@ -354,11 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if xformers is available, use it, otherwise use sliced attention.
"""
config = InvokeAIAppConfig.get_config()
if (
torch.cuda.is_available()
and is_xformers_available()
and not config.disable_xformers
):
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()
else:
if self.device.type == "cpu" or self.device.type == "mps":
@ -369,9 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = (
latents.element_size() + 4
)
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
@ -380,9 +360,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (
mem_free * 3.0 / 4.0
): # 3.3 / 4.0 is from old Invoke code
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
@ -470,7 +448,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_data: List[ControlNetData] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
@ -488,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=run_id,
additional_guidance=additional_guidance,
control_data=control_data,
callback=callback,
)
return result.latents, result.attention_map_saver
@ -511,9 +488,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(
run_id=run_id,
@ -607,16 +584,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
control_latent_input = unet_latent_input
else:
@ -629,7 +605,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else:
encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
(
encoder_hidden_states,
encoder_attention_mask,
) = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
@ -646,9 +625,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
@ -678,13 +657,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(
noise_pred, timestep, latents, **conditioning_data.scheduler_args
)
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
# But the way things are now, scheduler runs _after_ that, so there was
@ -710,17 +687,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
mask=torch.ones_like(
latents[:1, :1], device=latents.device, dtype=latents.dtype
),
initial_image_latents=torch.zeros_like(
latents[:1], device=latents.device, dtype=latents.dtype
),
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
latents,
t,
text_embeddings,
cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
).sample
@ -774,9 +750,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like(
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
),
latents=initial_latents
if strength < 1.0
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
timesteps=timesteps,
@ -797,14 +773,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps(
self, num_inference_steps: int, strength: float, device=None
) -> (torch.Tensor, int):
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
@ -849,18 +823,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype
)
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)
if mask.dim() == 3:
mask = mask.unsqueeze(0)
latent_mask = tv_resize(
mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
).to(device=device, dtype=latents_dtype)
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR).to(
device=device, dtype=latents_dtype
)
guidance: List[Callable] = []
@ -868,22 +840,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
masked_latents = self.non_noised_latents_from_image(
masked_init_image, device=device, dtype=latents_dtype
)
masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, latent_mask, masked_latents
)
else:
guidance.append(
AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
)
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
try:
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=init_image_latents if strength < 1.0 else torch.zeros_like(
latents=init_image_latents
if strength < 1.0
else torch.zeros_like(
init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
),
num_inference_steps=num_inference_steps,
@ -914,18 +884,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(
dtype=dtype
) # FIXME: uses torch.randn. make reproducible!
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
return init_latents
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, dtype=dtype
)
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@ -949,9 +915,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def debug_latents(self, latents, msg):
from invokeai.backend.image_util import debug_image
with torch.inference_mode():
decoded = self.numpy_to_pil(self.decode_latents(latents))
for i, img in enumerate(decoded):
debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
)
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)

View File

@ -17,6 +17,7 @@ from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -55,9 +56,7 @@ class Context:
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(
model, CrossAttentionType.TOKENS
):
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name)
@ -68,9 +67,7 @@ class Context:
else:
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(
self, cross_attention_type: CrossAttentionType
):
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
@ -139,9 +136,7 @@ class Context:
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
raise RuntimeError(
f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
)
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]
if saved_attention_dict["dim"] == requested_dim:
@ -154,21 +149,13 @@ class Context:
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
return whole_saved_attention[
requested_offset : requested_offset + slice_size
]
return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[
:, requested_offset : requested_offset + slice_size
]
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
raise RuntimeError(
f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
)
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(
self, identifier: str
) -> tuple[Optional[int], Optional[int]]:
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
@ -201,9 +188,7 @@ class InvokeAICrossAttentionMixin:
def set_attention_slice_wrangler(
self,
wrangler: Optional[
Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
],
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
@ -219,14 +204,10 @@ class InvokeAICrossAttentionMixin:
"""
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(
self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
):
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(
self, callback: Optional[Callable[[torch.Tensor], None]]
):
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
@ -247,45 +228,31 @@ class InvokeAICrossAttentionMixin:
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(
dim=-1, dtype=attention_scores.dtype
)
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(
self, default_attention_slice, dim, offset, slice_size
)
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(
attention_slice, dim, offset, slice_size
)
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
)
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(
q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
)
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
)
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(
q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
@ -353,6 +320,7 @@ def restore_default_cross_attention(
else:
remove_attention_function(model)
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -372,7 +340,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
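A toy, hedged illustration of the opcode handling above (values invented; indices_target is assumed to be an arange over the original prompt's token positions): for an "equal" opcode, positions b0:b1 of the edited prompt reuse attention from positions a0:a1 of the original, and the mask marks them as such.

    import torch

    max_length = 8
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.arange(max_length, dtype=torch.long)
    mask = torch.zeros(max_length)

    # pretend the opcodes say tokens 2..4 of the edited prompt equal tokens 1..3 of the original
    name, a0, a1, b0, b1 = "equal", 1, 4, 2, 5
    if name == "equal" and b0 < max_length:
        indices[b0:b1] = indices_target[a0:a1]  # remap to the original token positions
        mask[b0:b1] = 1                         # take saved attention for these tokens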
@ -386,16 +354,14 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
slice_size = next(
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(
model, which: CrossAttentionType
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = (
InvokeAIDiffusersCrossAttention
)
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
@ -420,9 +386,7 @@ def get_cross_attention_modules(
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(
module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
):
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attention_slice = suggested_attention_slice
@ -430,9 +394,7 @@ def inject_attention_function(unet, context: Context):
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
slice_to_save = (
attention_slice.to("cpu") if dim is not None else attention_slice
)
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
@ -442,31 +404,20 @@ def inject_attention_function(unet, context: Context):
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(
module.identifier, dim, offset, slice_size
)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(
suggested_attention_slice.device
)
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(
saved_attention_slice, -1, index_map
)
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask.to(
torch_dtype(suggested_attention_slice.device)
)
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
attention_slice = (
remapped_saved_attention_slice * saved_mask
+ this_attention_slice * this_mask
)
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
@ -480,14 +431,10 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(
lambda module: context.get_slicing_strategy(identifier)
)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
) # TODO
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
@ -503,9 +450,7 @@ def remove_attention_function(unet):
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
)
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
@ -530,9 +475,7 @@ def get_mem_free_total(device):
return mem_free_total
class InvokeAIDiffusersCrossAttention(
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
):
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
@ -641,11 +584,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
):
attention_type = (
CrossAttentionType.SELF
if encoder_hidden_states is None
else CrossAttentionType.TOKENS
)
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if (
@ -654,9 +593,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
):
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(
attn, hidden_states, encoder_hidden_states, attention_mask
)
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
# else:
# print(f"SwapCrossAttnContext for {attention_type} active")
@ -699,18 +636,10 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx]
attn_mask_slice = (
attention_mask[start_idx:end_idx]
if attention_mask is not None
else None
)
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
original_attn_slice = attn.get_attention_scores(
query_slice, original_key_slice, attn_mask_slice
)
modified_attn_slice = attn.get_attention_scores(
query_slice, modified_key_slice, attn_mask_slice
)
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
@ -722,9 +651,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attn_slice = (
remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
)
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
del remapped_original_attn_slice, modified_attn_slice
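The remap-and-blend described in the comments above, sketched in isolation (toy shapes and values, not the processor's actual call path):

    import torch

    original_attn = torch.rand(2, 6)   # 2 query positions over 6 original tokens
    modified_attn = torch.rand(2, 6)   # attention computed against the edited prompt
    index_map = torch.tensor([0, 1, 1, 2, 3, 4])          # original position each edited token maps to
    mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 0.0])   # 1 = keep the original attention for this token

    remapped_original = torch.index_select(original_attn, -1, index_map)
    attn_slice = remapped_original * mask + modified_attn * (1 - mask)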
@ -744,6 +671,4 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(
slice_size=int(1e9)
) # massive slice size = don't slice
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice

View File

@ -59,9 +59,7 @@ class AttentionMapSaver:
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(
maps.shape[0] / (latents_width * latents_height)
)
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
@ -72,9 +70,7 @@ class AttentionMapSaver:
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(
maps, [latents_height, latents_width], InterpolationMode.BICUBIC
)
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
@ -83,9 +79,7 @@ class AttentionMapSaver:
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(
maps_normalized_expanded, 0, 1
)
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(

View File

@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@ -81,14 +82,12 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
step_count: int,
):
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
):
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
@ -116,27 +115,15 @@ class InvokeAIDiffuserComponent:
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = (
"down"
if identifier.startswith("down")
else "up"
if identifier.startswith("up")
else "mid"
)
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(
slice, dim, offset, slice_size, key
)
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
@ -171,10 +158,8 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
context.get_active_cross_attention_control_types_for_step(
percent_through
)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
percent_through
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
@ -182,7 +167,11 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
elif wants_cross_attention_control:
(
@ -201,7 +190,11 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
else:
@ -209,12 +202,18 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
combined_next_x = self._combine(
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
unconditioned_next_x,
conditioned_next_x,
guidance_scale,
)
return combined_next_x
@ -229,37 +228,47 @@ class InvokeAIDiffuserComponent:
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = step_index / total_step_count
latents = self.apply_threshold(
postprocessing_settings, latents, percent_through
)
latents = self.apply_symmetry(
postprocessing_settings, latents, percent_through
)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
conditioning_attention_mask = torch.cat(
[
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
],
dim=1,
)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat(
[
cond,
torch.zeros(
(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
device=cond.device,
dtype=cond.dtype,
),
],
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
encoder_attention_mask = torch.cat(
[
encoder_attention_mask,
conditioning_attention_mask,
]
)
return cond, encoder_attention_mask
encoder_attention_mask = None
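A minimal sketch of the padding performed by _pad_conditioning above, assuming an unconditional embedding of 77 tokens and a conditional embedding of 154 (shapes illustrative): the shorter tensor is zero-padded to the longer length and an attention mask records which positions are real.

    import torch

    uncond = torch.randn(1, 77, 768)
    cond = torch.randn(1, 154, 768)
    max_len = max(uncond.shape[1], cond.shape[1])

    def pad_to(c, target_len):
        mask = torch.ones(c.shape[0], c.shape[1], dtype=c.dtype)
        if c.shape[1] < target_len:
            pad = target_len - c.shape[1]
            mask = torch.cat([mask, torch.zeros(c.shape[0], pad, dtype=c.dtype)], dim=1)
            c = torch.cat([c, torch.zeros(c.shape[0], pad, c.shape[2], dtype=c.dtype)], dim=1)
        return c, mask

    uncond, uncond_mask = pad_to(uncond, max_len)  # now (1, 154, 768) with mask (1, 154)
    cond, cond_mask = pad_to(cond, max_len)        # unchanged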
@ -277,11 +286,11 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings,
x_twice,
sigma_twice,
both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
@ -312,13 +321,17 @@ class InvokeAIDiffuserComponent:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
x,
sigma,
unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
x,
sigma,
conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
@ -335,13 +348,15 @@ class InvokeAIDiffuserComponent:
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]])
for i in range(len(conditioning[k]))
torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice,
sigma_twice,
both_conditionings,
**kwargs,
).chunk(2)
return unconditioned_next_x, conditioned_next_x
@ -388,9 +403,7 @@ class InvokeAIDiffuserComponent:
)
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = (
cross_attention_control_types_to_do
)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
@ -414,19 +427,14 @@ class InvokeAIDiffuserComponent:
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
if (
postprocessing_settings.threshold is None
or postprocessing_settings.threshold == 0.0
):
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
if percent_through < warmup:
current_threshold = threshold + threshold * 5 * (
1 - (percent_through / warmup)
)
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
else:
current_threshold = threshold
@ -440,18 +448,10 @@ class InvokeAIDiffuserComponent:
if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)]
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
logger.info(
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
)
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})")
logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}")
logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")
if maxval < current_threshold and minval > -current_threshold:
return latents
@ -464,25 +464,17 @@ class InvokeAIDiffuserComponent:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
latents[latents > maxval] = (
torch.rand_like(latents[latents > maxval]) * maxval
)
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
latents[latents < minval] = (
torch.rand_like(latents[latents < minval]) * minval
)
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
if self.debug_thresholding:
logger.debug(
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
)
logger.debug(
f"{num_altered / latents.numel() * 100:.2f}% values altered"
)
logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})")
logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered")
return latents
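A worked example of the warm-up arithmetic above (numbers illustrative): with threshold=1.0 and warmup=0.2, current_threshold starts near six times the configured value at the first step and decays to it by the end of the warm-up window, after which it stays constant.

    threshold, warmup = 1.0, 0.2
    for percent_through in (0.05, 0.1, 0.2, 0.5):
        if percent_through < warmup:
            current = threshold + threshold * 5 * (1 - (percent_through / warmup))
        else:
            current = threshold
        print(percent_through, current)  # 0.05 -> 4.75, 0.1 -> 3.5, 0.2 -> 1.0, 0.5 -> 1.0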
@ -501,15 +493,11 @@ class InvokeAIDiffuserComponent:
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if h_symmetry_time_pct is not None and (
h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
):
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if v_symmetry_time_pct is not None and (
v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
):
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None
dev = latents.device.type
@ -554,9 +542,7 @@ class InvokeAIDiffuserComponent:
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(
self.cross_attention_control_context.step_count
)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@ -567,19 +553,13 @@ class InvokeAIDiffuserComponent:
# todo: make this work
@classmethod
def apply_conjunction(
cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
):
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = (
c_or_weighted_c_list
if type(c_or_weighted_c_list) is list
else [(c_or_weighted_c_list, 1)]
)
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
@ -608,15 +588,11 @@ class InvokeAIDiffuserComponent:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(
weights[1:], dtype=deltas.dtype, device=deltas.device
)
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(
per_delta_weights.shape + (1, 1, 1)
)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)

View File

@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
year={2018}
}
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode="wrap"
) # 'nearest' | 'mirror'
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
@ -389,21 +387,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -413,21 +405,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -440,9 +426,7 @@ def add_Poisson_noise(img):
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
)
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -451,9 +435,7 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@ -540,9 +522,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
@ -646,9 +626,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
@ -796,9 +774,7 @@ if __name__ == "__main__":
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
@ -812,7 +788,5 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + ".png")

View File

@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
year={2018}
}
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode="wrap"
) # 'nearest' | 'mirror'
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
@ -393,21 +391,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -417,21 +409,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -444,9 +430,7 @@ def add_Poisson_noise(img):
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
)
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -455,9 +439,7 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@ -544,9 +526,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
@ -653,9 +633,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
@ -705,9 +683,9 @@ if __name__ == "__main__":
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
"image"
]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
@ -721,7 +699,5 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + ".png")

View File

@ -11,6 +11,7 @@ from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -296,22 +297,14 @@ def single2uint16(img):
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
# convert 2/3/4-dimensional torch tensor to uint
@ -334,12 +327,7 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
# convert torch tensor to single
@ -362,12 +350,7 @@ def tensor2single3(img):
def single2tensor5(img):
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
def single32tensor5(img):
@ -385,9 +368,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
"""
tensor = (
tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
@ -400,11 +381,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
elif n_dim == 2:
img_np = tensor.numpy()
else:
raise TypeError(
"Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(
n_dim
)
)
raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim))
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
@ -744,9 +721,7 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
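The expression above is the standard per-window SSIM index; written out in LaTeX with the usual stabilizing constants:

    \mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}

The mean over all windows is returned, matching ssim_map.mean().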
@ -767,9 +742,7 @@ def cubic(x):
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
@ -793,9 +766,9 @@ def calculate_weights_indices(
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
0, P - 1, P
).view(1, P).expand(out_length, P)
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
out_length, P
)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
@ -876,9 +849,7 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
)
out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
@ -959,9 +930,7 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
)
out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying

View File

@ -95,10 +95,7 @@ class ModelGroup(metaclass=ABCMeta):
pass
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} object at {id(self):x}: "
f"device={self.execution_device} >"
)
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
@ -143,8 +140,7 @@ class LazilyLoadedModelGroup(ModelGroup):
self.load(module)
if len(forward_input) == 0:
warnings.warn(
f"Hook for {module.__class__.__name__} got no input. "
f"Inputs must be positional, not keywords.",
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
stacklevel=3,
)
return send_to_device(forward_input, self.execution_device)
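The warning above exists because torch forward pre-hooks only receive positional arguments; a small sketch of that behaviour (the module and hook names are illustrative):

    import torch

    linear = torch.nn.Linear(4, 4)

    def report_inputs(module, forward_input):
        # forward_input is a tuple of the positional args only
        print(len(forward_input))

    linear.register_forward_pre_hook(report_inputs)
    linear(torch.randn(2, 4))        # prints 1
    linear(input=torch.randn(2, 4))  # prints 0 -- keyword arguments bypass the hook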
@ -161,9 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert (
self.is_empty()
), f"A model is already loaded: {self._current_model_ref()}"
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
@ -192,12 +186,8 @@ class LazilyLoadedModelGroup(ModelGroup):
def device_for(self, model):
if model not in self:
raise KeyError(
f"This does not manage this model {type(model).__name__}", model
)
return (
self.execution_device
) # this implementation only dispatches to one device
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
@ -256,12 +246,8 @@ class FullyLoadedModelGroup(ModelGroup):
def device_for(self, model):
if model not in self:
raise KeyError(
"This does not manage this model f{type(model).__name__}", model
)
return (
self.execution_device
) # this implementation only dispatches to one device
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models

View File

@ -1 +1 @@
from .schedulers import SCHEDULER_MAP
from .schedulers import SCHEDULER_MAP

View File

@ -1,7 +1,19 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
from diffusers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
DPMSolverSinglestepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSDEScheduler,
)
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
@ -21,9 +33,9 @@ SCHEDULER_MAP = dict(
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
)
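Each entry pairs a diffusers scheduler class with extra constructor kwargs. A hedged sketch of how such an entry could be consumed (the pipeline variable is assumed, and this is not necessarily how the rest of the codebase wires it up); note that cpu_only is read back out of the scheduler config by the pipeline rather than being a diffusers argument.

    scheduler_class, extra_kwargs = SCHEDULER_MAP["dpmpp_2m_k"]
    scheduler = scheduler_class.from_config(pipeline.scheduler.config, **extra_kwargs)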

View File

@ -45,7 +45,7 @@ from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager_service import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
@ -75,24 +75,16 @@ check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
def save_progress(
text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path
):
def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
logger.info("Saving embeddings")
learned_embeds = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[placeholder_token_id]
)
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
def parse_args():
config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser(
description="Textual inversion training"
)
parser = PagingArgumentParser(description="Textual inversion training")
general_group = parser.add_argument_group("General")
model_group = parser.add_argument_group("Models and Paths")
image_group = parser.add_argument_group("Training Image Location and Options")
@ -221,9 +213,7 @@ def parse_args():
default=100,
help="How many times to repeat the training data.",
)
training_group.add_argument(
"--seed", type=int, default=None, help="A seed for reproducible training."
)
training_group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
training_group.add_argument(
"--train_batch_size",
type=int,
@ -287,9 +277,7 @@ def parse_args():
default=0.999,
help="The beta2 parameter for the Adam optimizer.",
)
training_group.add_argument(
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
)
training_group.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
training_group.add_argument(
"--adam_epsilon",
type=float,
@ -442,9 +430,7 @@ class TextualInversionDataset(Dataset):
self.data_root / file_path
for file_path in self.data_root.iterdir()
if file_path.is_file()
and file_path.name.endswith(
(".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF")
)
and file_path.name.endswith((".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF"))
]
self.num_images = len(self.image_paths)
@ -460,11 +446,7 @@ class TextualInversionDataset(Dataset):
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = (
imagenet_style_templates_small
if learnable_property == "style"
else imagenet_templates_small
)
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self):
@ -500,9 +482,7 @@ class TextualInversionDataset(Dataset):
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
]
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
image = Image.fromarray(img)
image = image.resize((self.size, self.size), resample=self.interpolation)
@ -515,9 +495,7 @@ class TextualInversionDataset(Dataset):
return example
def get_full_repo_name(
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
):
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
@ -570,9 +548,7 @@ def do_textual_inversion_training(
**kwargs,
):
assert model, "Please specify a base model with --model"
assert (
train_data_dir
), "Please specify a directory containing the training images using --train_data_dir"
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != local_rank:
@ -593,7 +569,7 @@ def do_textual_inversion_training(
project_config=accelerator_config,
)
model_manager = ModelManagerService(config,logger)
model_manager = ModelManagerService(config, logger)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@ -633,13 +609,11 @@ def do_textual_inversion_training(
os.makedirs(output_dir, exist_ok=True)
known_models = model_manager.model_names()
model_name = model.split('/')[-1]
model_name = model.split("/")[-1]
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
model_info = model_manager.model_info(*model_meta)
assert (
model_info['model_format'] == "diffusers"
), "This script only works with models of type 'diffusers'"
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
@ -650,9 +624,7 @@ def do_textual_inversion_training(
if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else:
tokenizer = CLIPTokenizer.from_pretrained(
tokenizer_info.location, subfolder='tokenizer', **pipeline_args
)
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(
@ -722,9 +694,7 @@ def do_textual_inversion_training(
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError(
"xformers is not available. Make sure it is installed correctly"
)
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@ -732,12 +702,7 @@ def do_textual_inversion_training(
torch.backends.cuda.matmul.allow_tf32 = True
if scale_lr:
learning_rate = (
learning_rate
* gradient_accumulation_steps
* train_batch_size
* accelerator.num_processes
)
learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
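# Illustrative arithmetic for the --scale_lr rule above (hypothetical values,
# not this script's defaults): with a base learning_rate of 5e-4,
# gradient_accumulation_steps == 2, train_batch_size == 4 and one process,
# the effective rate becomes 5e-4 * 2 * 4 * 1 == 4e-3, i.e. the base rate is
# scaled by the number of examples consumed per optimizer step.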
# Initialize the optimizer
optimizer = torch.optim.AdamW(
@ -759,15 +724,11 @@ def do_textual_inversion_training(
center_crop=center_crop,
set="train",
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if max_train_steps is None:
max_train_steps = num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
@ -797,9 +758,7 @@ def do_textual_inversion_training(
vae.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if overrode_max_train_steps:
max_train_steps = num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
@ -814,17 +773,13 @@ def do_textual_inversion_training(
accelerator.init_trackers("textual_inversion", config=params)
# Train!
total_batch_size = (
train_batch_size * accelerator.num_processes * gradient_accumulation_steps
)
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_train_steps}")
global_step = 0
@ -843,9 +798,7 @@ def do_textual_inversion_training(
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
@ -854,9 +807,7 @@ def do_textual_inversion_training(
resume_global_step = global_step * gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (
num_update_steps_per_epoch * gradient_accumulation_steps
)
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
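# Worked example of the resume arithmetic above (illustrative numbers): with
# global_step == 150, gradient_accumulation_steps == 2 and
# num_update_steps_per_epoch == 100, resume_global_step == 300,
# first_epoch == 1 and resume_step == 300 % 200 == 100, so the first 100
# dataloader batches of epoch 1 are skipped below.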
# Only show the progress bar once on each machine.
progress_bar = tqdm(
@ -866,33 +817,20 @@ def do_textual_inversion_training(
progress_bar.set_description("Steps")
# keep original embeddings as reference
orig_embeds_params = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight.data.clone()
)
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
for epoch in range(first_epoch, num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if (
resume_step
and resume_from_checkpoint
and epoch == first_epoch
and step < resume_step
):
if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = (
vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
.latent_dist.sample()
.detach()
)
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
@ -912,14 +850,10 @@ def do_textual_inversion_training(
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
dtype=weight_dtype
)
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
# Predict the noise residual
model_pred = unet(
noisy_latents, timesteps, encoder_hidden_states
).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
@ -927,9 +861,7 @@ def do_textual_inversion_training(
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@ -942,22 +874,16 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(
text_encoder
).get_input_embeddings().weight[
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[
index_no_updates
]
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % save_steps == 0:
save_path = os.path.join(
output_dir, f"learned_embeds-steps-{global_step}.bin"
)
save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(
text_encoder,
placeholder_token_id,
@ -968,9 +894,7 @@ def do_textual_inversion_training(
if global_step % checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(
output_dir, f"checkpoint-{global_step}"
)
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
@ -985,9 +909,7 @@ def do_textual_inversion_training(
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if push_to_hub and only_save_embeds:
logger.warn(
"Enabling full model saving because --push_to_hub=True was specified."
)
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not only_save_embeds
@ -1012,8 +934,6 @@ def do_textual_inversion_training(
)
if push_to_hub:
repo.push_to_hub(
commit_message="End of training", blocking=False, auto_lfs_prune=True
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()

View File

@ -11,12 +11,4 @@ from .devices import (
torch_dtype,
)
from .log import write_log
from .util import (
ask_user,
download_with_resume,
instantiate_from_config,
url_attachment_name,
Chdir
)
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir

View File

@ -12,6 +12,7 @@ CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if config.always_use_cpu:

View File

@ -20,6 +20,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
@ -618,9 +619,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
@ -630,5 +629,6 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages
@ -186,89 +186,109 @@ from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
try:
import syslog
SYSLOG_AVAILABLE = True
except:
SYSLOG_AVAILABLE = False
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs):
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs):
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL):
InvokeAILogger.getLogger().disable(level)
def basicConfig(**kwargs):
InvokeAILogger.getLogger().basicConfig(**kwargs)
def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
_FACILITY_MAP = dict(
LOG_KERN = syslog.LOG_KERN,
LOG_USER = syslog.LOG_USER,
LOG_MAIL = syslog.LOG_MAIL,
LOG_DAEMON = syslog.LOG_DAEMON,
LOG_AUTH = syslog.LOG_AUTH,
LOG_LPR = syslog.LOG_LPR,
LOG_NEWS = syslog.LOG_NEWS,
LOG_UUCP = syslog.LOG_UUCP,
LOG_CRON = syslog.LOG_CRON,
LOG_SYSLOG = syslog.LOG_SYSLOG,
LOG_LOCAL0 = syslog.LOG_LOCAL0,
LOG_LOCAL1 = syslog.LOG_LOCAL1,
LOG_LOCAL2 = syslog.LOG_LOCAL2,
LOG_LOCAL3 = syslog.LOG_LOCAL3,
LOG_LOCAL4 = syslog.LOG_LOCAL4,
LOG_LOCAL5 = syslog.LOG_LOCAL5,
LOG_LOCAL6 = syslog.LOG_LOCAL6,
LOG_LOCAL7 = syslog.LOG_LOCAL7,
) if SYSLOG_AVAILABLE else dict()
_SOCK_MAP = dict(
SOCK_STREAM = socket.SOCK_STREAM,
SOCK_DGRAM = socket.SOCK_DGRAM,
_FACILITY_MAP = (
dict(
LOG_KERN=syslog.LOG_KERN,
LOG_USER=syslog.LOG_USER,
LOG_MAIL=syslog.LOG_MAIL,
LOG_DAEMON=syslog.LOG_DAEMON,
LOG_AUTH=syslog.LOG_AUTH,
LOG_LPR=syslog.LOG_LPR,
LOG_NEWS=syslog.LOG_NEWS,
LOG_UUCP=syslog.LOG_UUCP,
LOG_CRON=syslog.LOG_CRON,
LOG_SYSLOG=syslog.LOG_SYSLOG,
LOG_LOCAL0=syslog.LOG_LOCAL0,
LOG_LOCAL1=syslog.LOG_LOCAL1,
LOG_LOCAL2=syslog.LOG_LOCAL2,
LOG_LOCAL3=syslog.LOG_LOCAL3,
LOG_LOCAL4=syslog.LOG_LOCAL4,
LOG_LOCAL5=syslog.LOG_LOCAL5,
LOG_LOCAL6=syslog.LOG_LOCAL6,
LOG_LOCAL7=syslog.LOG_LOCAL7,
)
if SYSLOG_AVAILABLE
else dict()
)
_SOCK_MAP = dict(
SOCK_STREAM=socket.SOCK_STREAM,
SOCK_DGRAM=socket.SOCK_DGRAM,
)
class InvokeAIFormatter(logging.Formatter):
'''
"""
Base class for logging formatter
'''
"""
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record)
@abstractmethod
def log_fmt(self, levelno: int)->str:
def log_fmt(self, levelno: int) -> str:
pass
class InvokeAISyslogFormatter(InvokeAIFormatter):
'''
"""
Formatting for syslog
'''
def log_fmt(self, levelno: int)->str:
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
"""
def log_fmt(self, levelno: int) -> str:
return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
'''
"""
Formatting for the InvokeAI Logger (legacy version)
'''
"""
FORMATS = {
logging.DEBUG: " | %(message)s",
logging.INFO: ">> %(message)s",
@ -276,20 +296,25 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
logging.ERROR: "*** %(message)s",
logging.CRITICAL: "### %(message)s",
}
def log_fmt(self,levelno:int)->str:
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
'''
"""
Custom Formatting for the InvokeAI Logger (plain version)
'''
def log_fmt(self, levelno: int)->str:
"""
def log_fmt(self, levelno: int) -> str:
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter):
'''
"""
Custom Formatting for the InvokeAI Logger
'''
"""
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
@ -308,32 +333,34 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
logging.INFO: grey + log_format + reset,
logging.WARNING: yellow + log_format + reset,
logging.ERROR: red + log_format + reset,
logging.CRITICAL: bold_red + log_format + reset
logging.CRITICAL: bold_red + log_format + reset,
}
def log_fmt(self, levelno: int)->str:
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
LOG_FORMATTERS = {
'plain': InvokeAIPlainLogFormatter,
'color': InvokeAIColorLogFormatter,
'syslog': InvokeAISyslogFormatter,
'legacy': InvokeAILegacyLogFormatter,
"plain": InvokeAIPlainLogFormatter,
"color": InvokeAIColorLogFormatter,
"syslog": InvokeAISyslogFormatter,
"legacy": InvokeAILegacyLogFormatter,
}
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls,
name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
def getLogger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger:
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
@ -344,82 +371,80 @@ class InvokeAILogger(object):
handler_strs = config.log_handlers
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None
# console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console':
if handler_name == "console":
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='syslog':
elif handler_name == "syslog":
ch = cls._parse_syslog_args(args)
handlers.append(ch)
elif handler_name=='file':
elif handler_name == "file":
ch = cls._parse_file_args(args)
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='http':
elif handler_name == "http":
ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers
@staticmethod
def _parse_syslog_args(
args: str=None
)-> logging.Handler:
def _parse_syslog_args(args: str = None) -> logging.Handler:
if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system")
if not args:
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = dict()
try:
for a in args.split(','):
arg_name,*arg_value = a.split(':',2)
if arg_name=='address':
host,*port = arg_value
port = 514 if len(port)==0 else int(port[0])
syslog_args['address'] = (host,port)
elif arg_name=='facility':
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
elif arg_name=='socktype':
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
for a in args.split(","):
arg_name, *arg_value = a.split(":", 2)
if arg_name == "address":
host, *port = arg_value
port = 514 if len(port) == 0 else int(port[0])
syslog_args["address"] = (host, port)
elif arg_name == "facility":
syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
elif arg_name == "socktype":
syslog_args["socktype"] = _SOCK_MAP[arg_value[0]]
else:
syslog_args['address'] = arg_name
syslog_args["address"] = arg_name
except:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod
def _parse_file_args(args: str=None)-> logging.Handler:
def _parse_file_args(args: str = None) -> logging.Handler:
if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args)
@staticmethod
def _parse_http_args(args: str=None)-> logging.Handler:
def _parse_http_args(args: str = None) -> logging.Handler:
if not args:
raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(',')
arg_list = args.split(",")
url = urllib.parse.urlparse(arg_list.pop(0))
if url.scheme != 'http':
if url.scheme != "http":
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
host = url.hostname
path = url.path
port = url.port or 80
syslog_args = dict()
for a in arg_list:
arg_name, *arg_value = a.split(':',2)
if arg_name=='method':
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
arg_name, *arg_value = a.split(":", 2)
if arg_name == "method":
arg_value = arg_value[0] if len(arg_value) > 0 else "GET"
syslog_args[arg_name] = arg_value
else: # TODO: Provide support for SSL context and credentials
pass
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)
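# A minimal sketch of how the handler strings above are consumed, assuming an
# InvokeAIAppConfig whose log_handlers entries look like the following (the
# file path and URL are hypothetical):
#     config.log_handlers = [
#         "console",                                      # StreamHandler + formatter
#         "file=/var/log/invokeai.log",                   # FileHandler
#         "syslog=address:localhost:514",                 # SysLogHandler
#         "http=http://localhost:8080/log,method:POST",   # HTTPHandler
#     ]
#     logger = InvokeAILogger.getLogger("InvokeAI", config)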

View File

@ -8,6 +8,8 @@ if torch.backends.mps.is_available():
_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@ -19,20 +21,26 @@ def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
torch.nn.functional.layer_norm = new_layer_norm
_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result
torch.Tensor.permute = new_torch_tensor_permute
_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@ -52,20 +60,36 @@ def new_torch_lerp(input, end, weight, *, out=None):
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
def new_torch_interpolate(
input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
return _torch_interpolate(
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate
# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
class ChunkedSlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
@ -78,7 +102,7 @@ class ChunkedSlicedAttnProcessor:
def __init__(self, slice_size):
assert isinstance(slice_size, int)
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
self.slice_size = slice_size
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
@ -121,7 +145,9 @@ class ChunkedSlicedAttnProcessor:
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
chunk_tmp_tensor = torch.empty(self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
chunk_tmp_tensor = torch.empty(
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
@ -131,7 +157,15 @@ class ChunkedSlicedAttnProcessor:
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
self.get_attention_scores_chunked(attn, query_slice, key_slice, attn_mask_slice, hidden_states[start_idx:end_idx], value[start_idx:end_idx], chunk_tmp_tensor)
self.get_attention_scores_chunked(
attn,
query_slice,
key_slice,
attn_mask_slice,
hidden_states[start_idx:end_idx],
value[start_idx:end_idx],
chunk_tmp_tensor,
)
hidden_states = attn.batch_to_head_dim(hidden_states)
@ -150,7 +184,6 @@ class ChunkedSlicedAttnProcessor:
return hidden_states
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
# batch size = 1
assert query.shape[0] == 1
@ -163,14 +196,14 @@ class ChunkedSlicedAttnProcessor:
query = query.float()
key = key.float()
#out_item_size = query.dtype.itemsize
#if attn.upcast_attention:
# out_item_size = query.dtype.itemsize
# if attn.upcast_attention:
# out_item_size = torch.float32.itemsize
out_item_size = query.element_size()
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2 ** 29
chunk_size = 2**29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
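# Worked example of the sizing above (illustrative figures): with 4096 query
# tokens, 4096 key tokens and float32 scores (out_item_size == 4),
# out_size == 4096 * 4096 * 4 bytes == 64 MiB, well under the 2**29-byte
# (512 MiB) chunk_size, so chunks_count == 1 and the scores fit in one chunk.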
@ -181,8 +214,8 @@ class ChunkedSlicedAttnProcessor:
def _get_chunk_view(tensor, start, length):
if start + length > tensor.shape[1]:
length = tensor.shape[1] - start
#print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:,start:start+length]
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:, start : start + length]
for chunk_pos in range(0, query.shape[1], chunk_step):
if attention_mask is not None:
@ -196,7 +229,7 @@ class ChunkedSlicedAttnProcessor:
)
else:
torch.baddbmm(
torch.zeros((1,1,1), device=query.device, dtype=query.dtype),
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
_get_chunk_view(query, chunk_pos, chunk_step),
key,
beta=0,
@ -206,7 +239,7 @@ class ChunkedSlicedAttnProcessor:
chunk = chunk.softmax(dim=-1)
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
#del chunk
# del chunk
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor

View File

@ -32,9 +32,7 @@ def log_txt_as_img(wh, xc, size=10):
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
@ -81,9 +79,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logger.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@ -154,21 +150,12 @@ def parallel_data_prefetch(
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
]
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
else:
step = (
int(len(data) / n_proc + 1)
if len(data) % n_proc != 0
else int(len(data) / n_proc)
)
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i : i + step] for i in range(0, len(data), step)]
)
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
]
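# Worked example of the step arithmetic above (illustrative numbers): with
# len(data) == 10 and n_proc == 3, 10 % 3 != 0 so step == int(10 / 3 + 1) == 4,
# giving the three workers slices of 4, 4 and 2 items.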
processes = []
for i in range(n_proc):
@ -220,9 +207,7 @@ def parallel_data_prefetch(
return gather_res
def rand_perlin_2d(
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
@ -265,9 +250,9 @@ def rand_perlin_2d(
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
t = fade(grid[: shape[0], : shape[1]])
noise = math.sqrt(2) * torch.lerp(
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
).to(device)
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
device
)
return noise.to(dtype=torch_dtype(device))
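# A minimal usage sketch (hypothetical values); shape is assumed to be an
# integer multiple of res so the grid arithmetic above lines up:
#     noise = rand_perlin_2d((512, 512), (8, 8), device=torch.device("cpu"))
#     # -> a 512x512 tensor of smooth noise in torch_dtype(device)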
@ -276,9 +261,7 @@ def ask_user(question: str, answers: list):
user_prompt = f"\n>> {question} {answers}: "
invalid_answer_msg = "Invalid answer. Please try again."
pose_question = chain(
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
)
pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
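# A minimal usage sketch (question and answers are hypothetical); the call
# keeps re-prompting until the reply is one of the allowed answers:
#     choice = ask_user("Overwrite the existing file?", ["y", "n"])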
@ -303,9 +286,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir():
try:
file_name = re.search(
'filename="(.+)"', resp.headers.get("Content-Disposition")
).group(1)
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
@ -322,7 +303,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
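# A minimal usage sketch (URL and destination are hypothetical): when dest is
# a directory the filename comes from Content-Disposition (or the URL), a
# partial download is resumed, and an already complete file is skipped:
#     path = download_with_resume(
#         "https://example.com/models/some-model.safetensors", Path("/tmp/models")
#     )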
@ -377,16 +358,16 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")
return image_base64
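# A minimal usage sketch (the image is synthetic):
#     from PIL import Image
#     data_url = image_to_dataURL(Image.new("RGB", (64, 64)), image_format="PNG")
#     # -> "data:image/png;base64,...."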
class Chdir(object):
'''Context manager to chdir to desired directory and change back after context exits:
"""Context manager to chdir to desired directory and change back after context exits:
Args:
path (Path): The path to the cwd
'''
"""
def __init__(self, path: Path):
self.path = path
self.original = Path().absolute()
@ -394,5 +375,5 @@ class Chdir(object):
def __enter__(self):
os.chdir(self.path)
def __exit__(self,*args):
def __exit__(self, *args):
os.chdir(self.original)
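# A minimal usage sketch (the directory is hypothetical); the original working
# directory is restored even if the body raises:
#     with Chdir(Path("/tmp")):
#         ...  # cwd is /tmp here
#     # cwd is back to the original directory here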

View File

@ -64,10 +64,7 @@ class InvokeAIWebServer:
self.ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg"}
def allowed_file(self, filename: str) -> bool:
return (
"." in filename
and filename.rsplit(".", 1)[1].lower() in self.ALLOWED_EXTENSIONS
)
return "." in filename and filename.rsplit(".", 1)[1].lower() in self.ALLOWED_EXTENSIONS
def run(self):
self.setup_app()
@ -99,9 +96,7 @@ class InvokeAIWebServer:
_cors = _cors.split(",")
socketio_args["cors_allowed_origins"] = _cors
self.app = Flask(
__name__, static_url_path="", static_folder=frontend.__path__[0]
)
self.app = Flask(__name__, static_url_path="", static_folder=frontend.__path__[0])
self.socketio = SocketIO(self.app, **socketio_args)
@ -192,9 +187,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(file_path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(pil_image, os.path.basename(file_path), self.thumbnail_image_path)
response = {
"url": self.get_url_from_image_path(file_path),
@ -237,12 +230,8 @@ class InvokeAIWebServer:
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
)
else:
logger.info(
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
)
logger.info(
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
)
logger.info("Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
logger.info(f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}")
if not useSSL:
self.socketio.run(app=self.app, host=self.host, port=self.port)
else:
@ -392,22 +381,16 @@ class InvokeAIWebServer:
@socketio.on("convertToDiffusers")
def convert_to_diffusers(model_to_convert: dict):
try:
if model_info := self.generate.model_manager.model_info(
model_name=model_to_convert["model_name"]
):
if model_info := self.generate.model_manager.model_info(model_name=model_to_convert["model_name"]):
if "weights" in model_info:
ckpt_path = Path(model_info["weights"])
original_config_file = Path(model_info["config"])
model_name = model_to_convert["model_name"]
model_description = model_info["description"]
else:
self.socketio.emit(
"error", {"message": "Model is not a valid checkpoint file"}
)
self.socketio.emit("error", {"message": "Model is not a valid checkpoint file"})
else:
self.socketio.emit(
"error", {"message": "Could not retrieve model info."}
)
self.socketio.emit("error", {"message": "Could not retrieve model info."})
if not ckpt_path.is_absolute():
ckpt_path = Path(Globals.root, ckpt_path)
@ -415,22 +398,13 @@ class InvokeAIWebServer:
if original_config_file and not original_config_file.is_absolute():
original_config_file = Path(Globals.root, original_config_file)
diffusers_path = Path(
ckpt_path.parent.absolute(), f"{model_name}_diffusers"
)
diffusers_path = Path(ckpt_path.parent.absolute(), f"{model_name}_diffusers")
if model_to_convert["save_location"] == "root":
diffusers_path = Path(
global_converted_ckpts_dir(), f"{model_name}_diffusers"
)
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
if (
model_to_convert["save_location"] == "custom"
and model_to_convert["custom_location"] is not None
):
diffusers_path = Path(
model_to_convert["custom_location"], f"{model_name}_diffusers"
)
if model_to_convert["save_location"] == "custom" and model_to_convert["custom_location"] is not None:
diffusers_path = Path(model_to_convert["custom_location"], f"{model_name}_diffusers")
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
@ -462,10 +436,7 @@ class InvokeAIWebServer:
def merge_diffusers_models(model_merge_info: dict):
try:
models_to_merge = model_merge_info["models_to_merge"]
model_ids_or_paths = [
self.generate.model_manager.model_name_or_path(x)
for x in models_to_merge
]
model_ids_or_paths = [self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
merged_pipe = merge_diffusion_models(
model_ids_or_paths,
model_merge_info["alpha"],
@ -487,15 +458,11 @@ class InvokeAIWebServer:
commit_to_conf=opt.conf,
)
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
"vae", None
):
if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
dump_path, **merged_model_config
)
self.generate.model_manager.import_diffuser_model(dump_path, **merged_model_config)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
@ -525,9 +492,7 @@ class InvokeAIWebServer:
)
os.remove(thumbnail_path)
except Exception as e:
socketio.emit(
"error", {"message": f"Unable to delete {f}: {str(e)}"}
)
socketio.emit("error", {"message": f"Unable to delete {f}: {str(e)}"})
pass
socketio.emit("tempFolderEmptied")
@ -550,9 +515,7 @@ class InvokeAIWebServer:
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(new_path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(pil_image, os.path.basename(new_path), self.thumbnail_image_path)
image_array = [
{
@ -577,18 +540,14 @@ class InvokeAIWebServer:
@socketio.on("requestLatestImages")
def handle_request_latest_images(category, latest_mtime):
try:
base_path = (
self.result_path if category == "result" else self.init_image_path
)
base_path = self.result_path if category == "result" else self.init_image_path
paths = []
for ext in ("*.png", "*.jpg", "*.jpeg"):
paths.extend(glob.glob(os.path.join(base_path, ext)))
image_paths = sorted(
paths, key=lambda x: os.path.getmtime(x), reverse=True
)
image_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True)
image_paths = list(
filter(
@ -609,16 +568,12 @@ class InvokeAIWebServer:
pil_image = Image.open(path)
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(pil_image, os.path.basename(path), self.thumbnail_image_path)
image_array.append(
{
"url": self.get_url_from_image_path(path),
"thumbnail": self.get_url_from_image_path(
thumbnail_path
),
"thumbnail": self.get_url_from_image_path(thumbnail_path),
"mtime": os.path.getmtime(path),
"metadata": metadata.get("sd-metadata"),
"dreamPrompt": metadata.get("Dream"),
@ -628,9 +583,7 @@ class InvokeAIWebServer:
}
)
except Exception as e:
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
pass
socketio.emit(
@ -645,17 +598,13 @@ class InvokeAIWebServer:
try:
page_size = 50
base_path = (
self.result_path if category == "result" else self.init_image_path
)
base_path = self.result_path if category == "result" else self.init_image_path
paths = []
for ext in ("*.png", "*.jpg", "*.jpeg"):
paths.extend(glob.glob(os.path.join(base_path, ext)))
image_paths = sorted(
paths, key=lambda x: os.path.getmtime(x), reverse=True
)
image_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True)
if earliest_mtime:
image_paths = list(
@ -679,16 +628,12 @@ class InvokeAIWebServer:
pil_image = Image.open(path)
(width, height) = pil_image.size
thumbnail_path = save_thumbnail(
pil_image, os.path.basename(path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(pil_image, os.path.basename(path), self.thumbnail_image_path)
image_array.append(
{
"url": self.get_url_from_image_path(path),
"thumbnail": self.get_url_from_image_path(
thumbnail_path
),
"thumbnail": self.get_url_from_image_path(thumbnail_path),
"mtime": os.path.getmtime(path),
"metadata": metadata.get("sd-metadata"),
"dreamPrompt": metadata.get("Dream"),
@ -699,9 +644,7 @@ class InvokeAIWebServer:
)
except Exception as e:
logger.info(f"Unable to load {path}")
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
pass
socketio.emit(
@ -716,9 +659,7 @@ class InvokeAIWebServer:
self.handle_exceptions(e)
@socketio.on("generateImage")
def handle_generate_image_event(
generation_parameters, esrgan_parameters, facetool_parameters
):
def handle_generate_image_event(generation_parameters, esrgan_parameters, facetool_parameters):
try:
# truncate long init_mask/init_img base64 if needed
printable_parameters = {
@ -726,14 +667,10 @@ class InvokeAIWebServer:
}
if "init_img" in generation_parameters:
printable_parameters["init_img"] = (
printable_parameters["init_img"][:64] + "..."
)
printable_parameters["init_img"] = printable_parameters["init_img"][:64] + "..."
if "init_mask" in generation_parameters:
printable_parameters["init_mask"] = (
printable_parameters["init_mask"][:64] + "..."
)
printable_parameters["init_mask"] = printable_parameters["init_mask"][:64] + "..."
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
@ -750,18 +687,14 @@ class InvokeAIWebServer:
@socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters):
try:
logger.info(
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
)
logger.info(f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}')
progress = Progress()
socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
original_image_path = self.get_image_path_from_url(
original_image["url"]
)
original_image_path = self.get_image_path_from_url(original_image["url"])
image = Image.open(original_image_path)
@ -801,14 +734,10 @@ class InvokeAIWebServer:
strength=postprocessing_parameters["facetool_strength"],
fidelity=postprocessing_parameters["codeformer_fidelity"],
seed=seed,
device="cpu"
if str(self.generate.device) == "mps"
else self.generate.device,
device="cpu" if str(self.generate.device) == "mps" else self.generate.device,
)
else:
raise TypeError(
f'{postprocessing_parameters["type"]} is not a valid postprocessing type'
)
raise TypeError(f'{postprocessing_parameters["type"]} is not a valid postprocessing type')
progress.set_current_status("common.statusSavingImage")
socketio.emit("progressUpdate", progress.to_formatted_dict())
@ -832,9 +761,7 @@ class InvokeAIWebServer:
postprocessing=postprocessing_parameters["type"],
)
thumbnail_path = save_thumbnail(
image, os.path.basename(path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(image, os.path.basename(path), self.thumbnail_image_path)
self.write_log_message(
f'[Postprocessed] "{original_image_path}" > "{path}": {postprocessing_parameters}'
@ -901,17 +828,13 @@ class InvokeAIWebServer:
"app_version": APP_VERSION,
}
def generate_images(
self, generation_parameters, esrgan_parameters, facetool_parameters
):
def generate_images(self, generation_parameters, esrgan_parameters, facetool_parameters):
try:
self.canceled.clear()
step_index = 1
prior_variations = (
generation_parameters["with_variations"]
if "with_variations" in generation_parameters
else []
generation_parameters["with_variations"] if "with_variations" in generation_parameters else []
)
actual_generation_mode = generation_parameters["generation_mode"]
@ -943,9 +866,7 @@ class InvokeAIWebServer:
original_bounding_box = generation_parameters["bounding_box"].copy()
initial_image = dataURL_to_image(
generation_parameters["init_img"]
).convert("RGBA")
initial_image = dataURL_to_image(generation_parameters["init_img"]).convert("RGBA")
"""
The outpaint image and mask are pre-cropped by the UI, so the bounding box we pass
@ -962,13 +883,9 @@ class InvokeAIWebServer:
generation_parameters["bounding_box"]["y"] = 0
# Convert mask dataURL to an image and convert to greyscale
mask_image = dataURL_to_image(
generation_parameters["init_mask"]
).convert("L")
mask_image = dataURL_to_image(generation_parameters["init_mask"]).convert("L")
actual_generation_mode = get_canvas_generation_mode(
initial_image, mask_image
)
actual_generation_mode = get_canvas_generation_mode(initial_image, mask_image)
"""
Apply the mask to the init image, creating a "mask" image with
@ -1018,9 +935,7 @@ class InvokeAIWebServer:
elif generation_parameters["generation_mode"] == "img2img":
init_img_url = generation_parameters["init_img"]
init_img_path = self.get_image_path_from_url(init_img_url)
generation_parameters["init_img"] = Image.open(init_img_path).convert(
"RGB"
)
generation_parameters["init_img"] = Image.open(init_img_path).convert("RGB")
def image_progress(intermediate_state: PipelineIntermediateState):
if self.canceled.is_set():
@ -1046,9 +961,7 @@ class InvokeAIWebServer:
}
progress.set_current_step(step + 1)
progress.set_current_status(
f"{generation_messages[actual_generation_mode]}"
)
progress.set_current_status(f"{generation_messages[actual_generation_mode]}")
progress.set_current_status_has_steps(True)
if (
@ -1057,9 +970,7 @@ class InvokeAIWebServer:
and step < generation_parameters["steps"] - 1
):
image = self.generate.sample_to_image(sample)
metadata = self.parameters_to_generated_image_metadata(
generation_parameters
)
metadata = self.parameters_to_generated_image_metadata(generation_parameters)
command = parameters_to_command(generation_parameters)
(width, height) = image.size
@ -1140,15 +1051,10 @@ class InvokeAIWebServer:
all_parameters = generation_parameters
postprocessing = False
if (
"variation_amount" in all_parameters
and all_parameters["variation_amount"] > 0
):
if "variation_amount" in all_parameters and all_parameters["variation_amount"] > 0:
first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = (
prior_variations + this_variation
)
all_parameters["with_variations"] = prior_variations + this_variation
all_parameters["seed"] = first_seed
elif "with_variations" in all_parameters:
all_parameters["seed"] = first_seed
@ -1186,9 +1092,7 @@ class InvokeAIWebServer:
if facetool_parameters["type"] == "gfpgan":
progress.set_current_status("common.statusRestoringFacesGFPGAN")
elif facetool_parameters["type"] == "codeformer":
progress.set_current_status(
"common.statusRestoringFacesCodeFormer"
)
progress.set_current_status("common.statusRestoringFacesCodeFormer")
progress.set_current_status_has_steps(False)
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
@ -1206,18 +1110,12 @@ class InvokeAIWebServer:
strength=facetool_parameters["strength"],
fidelity=facetool_parameters["codeformer_fidelity"],
seed=seed,
device="cpu"
if str(self.generate.device) == "mps"
else self.generate.device,
device="cpu" if str(self.generate.device) == "mps" else self.generate.device,
)
all_parameters["codeformer_fidelity"] = facetool_parameters[
"codeformer_fidelity"
]
all_parameters["codeformer_fidelity"] = facetool_parameters["codeformer_fidelity"]
postprocessing = True
all_parameters["facetool_strength"] = facetool_parameters[
"strength"
]
all_parameters["facetool_strength"] = facetool_parameters["strength"]
all_parameters["facetool_type"] = facetool_parameters["type"]
progress.set_current_status("common.statusSavingImage")
@ -1226,9 +1124,7 @@ class InvokeAIWebServer:
# restore the stashed URLs and discard the paths; we are about to send the result to the client
all_parameters["init_img"] = (
init_img_url
if generation_parameters["generation_mode"] == "img2img"
else ""
init_img_url if generation_parameters["generation_mode"] == "img2img" else ""
)
if "init_mask" in all_parameters:
@ -1246,8 +1142,7 @@ class InvokeAIWebServer:
generated_image_outdir = (
self.result_path
if generation_parameters["generation_mode"]
in ["txt2img", "img2img"]
if generation_parameters["generation_mode"] in ["txt2img", "img2img"]
else self.temp_image_path
)
@ -1259,9 +1154,7 @@ class InvokeAIWebServer:
postprocessing=postprocessing,
)
thumbnail_path = save_thumbnail(
image, os.path.basename(path), self.thumbnail_image_path
)
thumbnail_path = save_thumbnail(image, os.path.basename(path), self.thumbnail_image_path)
logger.info(f'Image generated: "{path}"\n')
self.write_log_message(f'[Generated] "{path}": {command}')
@ -1281,14 +1174,10 @@ class InvokeAIWebServer:
tokens = (
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
model.tokenizer, parsed_prompt
)
else get_tokens_for_prompt_object(model.tokenizer, parsed_prompt)
)
attention_maps_image_base64_url = (
None
if attention_maps_image is None
else image_to_dataURL(attention_maps_image)
None if attention_maps_image is None else image_to_dataURL(attention_maps_image)
)
self.socketio.emit(
@ -1382,9 +1271,7 @@ class InvokeAIWebServer:
}
if parameters["facetool_type"] == "codeformer":
facetool_parameters["fidelity"] = float(
parameters["codeformer_fidelity"]
)
facetool_parameters["fidelity"] = float(parameters["codeformer_fidelity"])
postprocessing.append(facetool_parameters)
@ -1398,9 +1285,7 @@ class InvokeAIWebServer:
}
)
rfc_dict["postprocessing"] = (
postprocessing if len(postprocessing) > 0 else None
)
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
@ -1409,22 +1294,15 @@ class InvokeAIWebServer:
variations = []
if "with_variations" in parameters:
variations = [
{"seed": x[0], "weight": x[1]}
for x in parameters["with_variations"]
]
variations = [{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]]
rfc_dict["variations"] = variations
if rfc_dict["type"] == "img2img":
rfc_dict["strength"] = parameters["strength"]
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(
self.get_image_path_from_url(parameters["init_img"])
)
rfc_dict["init_image_path"] = parameters[
"init_img"
] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(self.get_image_path_from_url(parameters["init_img"]))
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
metadata["image"] = rfc_dict
@ -1433,9 +1311,7 @@ class InvokeAIWebServer:
except Exception as e:
self.handle_exceptions(e)
def parameters_to_post_processed_image_metadata(
self, parameters, original_image_path
):
def parameters_to_post_processed_image_metadata(self, parameters, original_image_path):
try:
current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
postprocessing_metadata = {}
@ -1447,9 +1323,7 @@ class InvokeAIWebServer:
if "image" not in current_metadata:
current_metadata["image"] = {}
orig_hash = calculate_init_img_hash(
self.get_image_path_from_url(original_image_path)
)
orig_hash = calculate_init_img_hash(self.get_image_path_from_url(original_image_path))
postprocessing_metadata["orig_path"] = (original_image_path,)
postprocessing_metadata["orig_hash"] = orig_hash
@ -1473,9 +1347,7 @@ class InvokeAIWebServer:
if "postprocessing" in current_metadata["image"] and isinstance(
current_metadata["image"]["postprocessing"], list
):
current_metadata["image"]["postprocessing"].append(
postprocessing_metadata
)
current_metadata["image"]["postprocessing"].append(postprocessing_metadata)
else:
current_metadata["image"]["postprocessing"] = [postprocessing_metadata]
@ -1556,29 +1428,17 @@ class InvokeAIWebServer:
"""Given a url to an image used by the client, returns the absolute file path to that image"""
try:
if "init-images" in url:
return os.path.abspath(
os.path.join(self.init_image_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.init_image_path, os.path.basename(url)))
elif "mask-images" in url:
return os.path.abspath(
os.path.join(self.mask_image_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.mask_image_path, os.path.basename(url)))
elif "intermediates" in url:
return os.path.abspath(
os.path.join(self.intermediate_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.intermediate_path, os.path.basename(url)))
elif "temp-images" in url:
return os.path.abspath(
os.path.join(self.temp_image_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.temp_image_path, os.path.basename(url)))
elif "thumbnails" in url:
return os.path.abspath(
os.path.join(self.thumbnail_image_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.thumbnail_image_path, os.path.basename(url)))
else:
return os.path.abspath(
os.path.join(self.result_path, os.path.basename(url))
)
return os.path.abspath(os.path.join(self.result_path, os.path.basename(url)))
except Exception as e:
self.handle_exceptions(e)
@ -1632,18 +1492,14 @@ class Progress:
self.total_steps = (
self._calculate_real_steps(
steps=generation_parameters["steps"],
strength=generation_parameters["strength"]
if "strength" in generation_parameters
else None,
strength=generation_parameters["strength"] if "strength" in generation_parameters else None,
has_init_image="init_img" in generation_parameters,
)
if generation_parameters
else 1
)
self.current_iteration = 1
self.total_iterations = (
generation_parameters["iterations"] if generation_parameters else 1
)
self.total_iterations = generation_parameters["iterations"] if generation_parameters else 1
self.current_status = "common.statusPreparing"
self.is_processing = True
self.current_status_has_steps = False
@ -1703,9 +1559,7 @@ class CanceledException(Exception):
pass
def copy_image_from_bounding_box(
image: ImageType, x: int, y: int, width: int, height: int
) -> ImageType:
def copy_image_from_bounding_box(image: ImageType, x: int, y: int, width: int, height: int) -> ImageType:
"""
Returns a copy of an image, cropped to a bounding box.
"""
@ -1740,9 +1594,7 @@ def image_to_dataURL(image: ImageType, image_format: str = "PNG") -> str:
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")
return image_base64

View File

@ -40,9 +40,7 @@ def get_canvas_generation_mode(
init_img_has_transparency = check_for_any_transparency(init_img)
if init_img_has_transparency:
init_img_is_fully_transparent = (
True if init_img_alpha_mask.getbbox() is None else False
)
init_img_is_fully_transparent = True if init_img_alpha_mask.getbbox() is None else False
"""
Mask images are white in areas where no change should be made, black where changes

View File

@ -10,7 +10,7 @@ SAMPLER_CHOICES = [
"lms_k",
"pndm",
"heun",
'heun_k',
"heun_k",
"euler",
"euler_k",
"euler_a",
@ -76,9 +76,7 @@ def parameters_to_command(params):
if "variation_amount" in params and params["variation_amount"] > 0:
switches.append(f'-v {params["variation_amount"]}')
if "with_variations" in params:
seed_weight_pairs = ",".join(
f"{seed}:{weight}" for seed, weight in params["with_variations"]
)
seed_weight_pairs = ",".join(f"{seed}:{weight}" for seed, weight in params["with_variations"])
switches.append(f"-V {seed_weight_pairs}")
return " ".join(switches)