Merge branch 'main' into ryan/multi-image-ip

This commit is contained in:
Ryan Dick
2023-10-18 08:59:12 -04:00
125 changed files with 3375 additions and 4664 deletions

View File

@ -20,12 +20,12 @@ class InvisibleWatermark:
"""
@classmethod
def invisible_watermark_available(self) -> bool:
def invisible_watermark_available(cls) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text: str) -> Image:
if not self.invisible_watermark_available():
def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
if not cls.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

View File

@ -26,8 +26,8 @@ class SafetyChecker:
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
def _load_safety_checker(cls):
if cls.tried_load:
return
if config.nsfw_checker:
@ -35,31 +35,31 @@ class SafetyChecker:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
self.tried_load = True
cls.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
def safety_checker_available(cls) -> bool:
cls._load_safety_checker()
return cls.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
cls.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

View File

@ -41,18 +41,18 @@ config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
def __init__(self, image: Image.Image, heatmap: torch.Tensor):
self.heatmap = heatmap
self.image = image
def to_grayscale(self, invert: bool = False) -> Image:
def to_grayscale(self, invert: bool = False) -> Image.Image:
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
def to_mask(self, threshold: float = 0.5) -> Image:
def to_mask(self, threshold: float = 0.5) -> Image.Image:
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
def to_transparent(self, invert: bool = False) -> Image:
def to_transparent(self, invert: bool = False) -> Image.Image:
transparent_image = self.image.copy()
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite. Thus invert.
@ -61,7 +61,7 @@ class SegmentedGrayscale(object):
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap: Image) -> Image:
def _rescale(self, heatmap: Image.Image) -> Image.Image:
size = self.image.width if (self.image.width > self.image.height) else self.image.height
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
return resized_image.crop((0, 0, self.image.width, self.image.height))
@ -82,7 +82,7 @@ class Txt2Mask(object):
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
@torch.no_grad()
def segment(self, image, prompt: str) -> SegmentedGrayscale:
def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
"""
Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter
@ -99,7 +99,7 @@ class Txt2Mask(object):
heatmap = torch.sigmoid(outputs.logits)
return SegmentedGrayscale(image, heatmap)
def _scale_and_crop(self, image: Image) -> Image:
def _scale_and_crop(self, image: Image.Image) -> Image.Image:
scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
if image.width > image.height: # width is constraint
scale = CLIPSEG_SIZE / image.width

View File

@ -9,7 +9,7 @@ class InitImageResizer:
def __init__(self, Image):
self.image = Image
def resize(self, width=None, height=None) -> Image:
def resize(self, width=None, height=None) -> Image.Image:
"""
Return a copy of the image resized to fit within
a box width x height. The aspect ratio is

View File

@ -793,7 +793,11 @@ def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
fields = [
x
for x, y in InvokeAIAppConfig.model_fields.items()
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
]
for attr in fields:
if hasattr(old, attr):
try:

View File

@ -236,13 +236,13 @@ import types
from dataclasses import dataclass
from pathlib import Path
from shutil import move, rmtree
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast
import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
@ -294,6 +294,8 @@ class AddModelResult(BaseModel):
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
model_config = ConfigDict(protected_namespaces=())
MAX_CACHE_SIZE = 6.0 # GB
@ -576,7 +578,7 @@ class ModelManager(object):
"""
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
return self.models[model_key].model_dump(exclude_defaults=True)
else:
return None # TODO: None or empty dict on not found
@ -632,7 +634,7 @@ class ModelManager(object):
continue
model_dict = dict(
**model_config.dict(exclude_defaults=True),
**model_config.model_dump(exclude_defaults=True),
# OpenAPIModelInfoBase
model_name=cur_model_name,
base_model=cur_base_model,
@ -900,14 +902,16 @@ class ModelManager(object):
Write current configuration out to the indicated file.
"""
data_to_save = dict()
data_to_save["__metadata__"] = self.config_meta.dict()
data_to_save["__metadata__"] = self.config_meta.model_dump()
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
data_to_save[model_key] = cast(BaseModel, model_config).model_dump(
exclude_defaults=True, exclude={"error"}, mode="json"
)
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum
from typing import Literal, get_origin
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, create_model
from .base import ( # noqa: F401
BaseModelType,
@ -106,6 +106,8 @@ class OpenAPIModelInfoBase(BaseModel):
base_model: BaseModelType
model_type: ModelType
model_config = ConfigDict(protected_namespaces=())
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
@ -121,17 +123,11 @@ for base_model, models in MODEL_CLASSES.items():
if openapi_cfg_name in vars():
continue
api_wrapper = type(
api_wrapper = create_model(
openapi_cfg_name,
(cfg, OpenAPIModelInfoBase),
dict(
__annotations__=dict(
model_type=Literal[model_type.value],
),
),
__base__=(cfg, OpenAPIModelInfoBase),
model_type=(Literal[model_type], model_type), # type: ignore
)
# globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)

View File

@ -19,7 +19,7 @@ from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from transformers import logging as transformers_logging
@ -86,14 +86,21 @@ class ModelError(str, Enum):
NotFound = "not_found"
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
if "required" not in schema:
schema["required"] = []
schema["required"].append("model_type")
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
model_config = ConfigDict(
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
)
class EmptyConfigLoader(ConfigMixin):

View File

@ -58,14 +58,16 @@ class IPAdapterModel(ModelBase):
def get_model(
self,
torch_dtype: Optional[torch.dtype],
torch_dtype: torch.dtype,
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=torch_dtype,
)
self.model_size = model.calc_size()

View File

@ -96,7 +96,7 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding
if hasattr(module, "asymmetric_padding_mode"):
del module.asymmetric_padding_mode
if hasattr(module, "asymmetric_padding"):
del module.asymmetric_padding

View File

@ -1,7 +1,8 @@
import math
from typing import Optional
import PIL
import torch
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
@ -11,7 +12,7 @@ class AttentionMapSaver:
self.token_ids = token_ids
self.latents_shape = latents_shape
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
self.collated_maps: dict[str, torch.Tensor] = {}
def clear_maps(self):
self.collated_maps = {}
@ -38,9 +39,10 @@ class AttentionMapSaver:
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, "PNG")
if pil_image is not None:
pil_image.save(path, "PNG")
def get_stacked_maps_image(self) -> PIL.Image:
def get_stacked_maps_image(self) -> Optional[Image.Image]:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
@ -95,4 +97,4 @@ class AttentionMapSaver:
return None
merged_bytes = merged.mul(0xFF).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
return Image.fromarray(merged_bytes.numpy(), mode="L")