diff --git a/.github/workflows/test-invoke-conda.yml b/.github/workflows/test-invoke-conda.yml
index a144303cc3..e9a0719040 100644
--- a/.github/workflows/test-invoke-conda.yml
+++ b/.github/workflows/test-invoke-conda.yml
@@ -4,7 +4,6 @@ on:
     branches:
       - 'main'
       - 'development'
-      - 'fix-gh-actions-fork'
   pull_request:
     branches:
       - 'main'
diff --git a/environment-linux-aarch64.yml b/environment-linux-aarch64.yml
index de762f8b85..c7a76b821b 100644
--- a/environment-linux-aarch64.yml
+++ b/environment-linux-aarch64.yml
@@ -2,15 +2,16 @@ name: invokeai
 channels:
   - pytorch
   - conda-forge
+  - defaults
 dependencies:
-  - python=3.9.*
+  - python=3.10
   - pip>=22.2.2
   - cudatoolkit
   - pytorch
   - torchvision
-  - numpy=1.19
-  - imageio=2.9.0
-  - opencv=4.6.0
+  - numpy=1.23
+  - imageio=2.21
+  - opencv=4.6
   - pillow=8.*
   - flask=2.1.*
   - flask_cors=3.0.10
@@ -25,21 +26,18 @@ dependencies:
   - einops=0.3.0
   - kornia=0.6
   - torchmetrics=0.7.0
-  - transformers=4.21.3
+  - transformers=4.23
   - torch-fidelity=0.3.0
   - tokenizers>=0.11.1,!=0.11.3,<0.13
   - pip:
     - getpass_asterisk
     - omegaconf==2.1.1
-    - realesrgan==0.2.5.0
-    - test-tube>=0.7.5
     - pyreadline3
-    - dependency_injector==4.40.0
-    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
-    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-    - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
-    - -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-    - -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+    - taming-transformers-rom1504
+    - test-tube>=0.7.5
+    - git+https://github.com/openai/CLIP.git@main#egg=clip
+    - git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
+    - git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
+    - git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
+    - git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
     - -e .
-variables:
-  PYTORCH_ENABLE_MPS_FALLBACK: 1
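NOTE: dropping the `-e` prefix above switches these VCS requirements from editable installs (cloned into a local src/ directory) to regular site-packages installs. A quick way to check how such a package ended up installed - a sketch, assuming the environment is active and that the distribution really is named `clipseg`:

    # PEP 610: pip records the origin of URL-based installs in direct_url.json
    from importlib.metadata import distribution

    info = distribution('clipseg').read_text('direct_url.json')
    print(info)  # a non-editable VCS install records the git URL and commit here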
diff --git a/environment-mac.yml b/environment-mac.yml
index e0db02c3b9..29f5197be9 100644
--- a/environment-mac.yml
+++ b/environment-mac.yml
@@ -2,12 +2,15 @@ name: invokeai
 channels:
   - pytorch
   - conda-forge
+  - defaults
 dependencies:
-  - python=3.9.13
-  - pip=22.2.2
-
-  - pytorch=1.12.1
-  - torchvision=0.13.1
+  - python=3.10
+  - pip>=22.2
+  - pytorch=1.12
+  - pytorch-lightning=1.7
+  - torchvision=0.13
+  - torchmetrics=0.10
+  - torch-fidelity=0.3

 # I suggest to keep the other deps sorted for convenience.
 # To determine what the latest versions should be, run:
@@ -17,49 +20,45 @@ dependencies:
 # CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac-updated.yml && conda list -n invokeai-updated | awk ' {print "  - " $1 "==" $2;} '
 # ```
-  - albumentations=1.2.1
-  - coloredlogs=15.0.1
-  - diffusers=0.6.0
-  - einops=0.4.1
-  - grpcio=1.46.4
+  - albumentations=1.2
+  - coloredlogs=15.0
+  - diffusers=0.6
+  - einops=0.3
+  - eventlet
+  - grpcio=1.46
+  - flask=2.1
+  - flask-socketio=5.3
+  - flask-cors=3.0
   - humanfriendly=10.0
-  - imageio=2.21.2
-  - imageio-ffmpeg=0.4.7
-  - imgaug=0.4.0
-  - kornia=0.6.7
-  - mpmath=1.2.1
-  - nomkl # arm64 has only 1.0 while x64 needs 3.0
-  - numpy=1.23.4
-  - omegaconf=2.1.1
-  - openh264=2.3.0
-  - onnx=1.12.0
-  - onnxruntime=1.12.1
-  - pudb=2022.1
-  - pytorch-lightning=1.7.7
-  - scipy=1.9.3
-  - streamlit=1.12.2
-  - sympy=1.10.1
-  - tensorboard=2.10.0
-  - torchmetrics=0.10.1
-  - py-opencv=4.6.0
-  - flask=2.1.3
-  - flask-socketio=5.3.0
-  - flask-cors=3.0.10
-  - eventlet=0.33.1
-  - protobuf=3.20.1
-  - send2trash=1.8.0
-  - transformers=4.23.1
-  - torch-fidelity=0.3.0
+  - imageio=2.21
+  - imageio-ffmpeg=0.4
+  - imgaug=0.4
+  - kornia=0.6
+  - mpmath=1.2
+  - nomkl=3
+  - numpy=1.23
+  - omegaconf=2.1
+  - openh264=2.3
+  - onnx=1.12
+  - onnxruntime=1.12
+  - pudb=2019.2
+  - protobuf=3.20
+  - py-opencv=4.6
+  - scipy=1.9
+  - streamlit=1.12
+  - sympy=1.10
+  - send2trash=1.8
+  - tensorboard=2.10
+  - transformers=4.23
   - pip:
     - getpass_asterisk
-    - dependency_injector==4.40.0
-    - realesrgan==0.2.5.0
+    - taming-transformers-rom1504
     - test-tube==0.7.5
-    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
-    - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
-    - -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-    - -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+    - git+https://github.com/openai/CLIP.git@main#egg=clip
+    - git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
+    - git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
+    - git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
+    - git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
     - -e .
 variables:
   PYTORCH_ENABLE_MPS_FALLBACK: 1
diff --git a/environment.yml b/environment.yml
index ae07e11c3a..3d5c44d391 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,7 +4,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - python>=3.9
+  - python=3.10
   - pip=22.2.2
   - numpy=1.23.3
   - torchvision=0.13.1
@@ -32,14 +32,13 @@ dependencies:
     - flask==2.1.3
     - flask_socketio==5.3.0
     - flask_cors==3.0.10
-    - dependency_injector==4.40.0
     - eventlet
     - getpass_asterisk
     - kornia==0.6.0
+    - taming-transformers-rom1504
     - git+https://github.com/openai/CLIP.git@main#egg=clip
-    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-    - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
-    - -e git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
-    - -e git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
-    - -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+    - git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
+    - git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
+    - git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
+    - git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
     - -e .
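NOTE: the three environment files above pin overlapping but not identical versions, so they drift apart easily. A minimal sanity check after editing any of them - a sketch, assuming PyYAML is available (conda itself depends on it):

    import yaml

    env = yaml.safe_load(open('environment.yml'))
    pins = [d for d in env['dependencies'] if isinstance(d, str)]
    print(env['name'], next(p for p in pins if p.startswith('python')))
    # e.g.: invokeai python=3.10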
diff --git a/frontend/README.md b/frontend/README.md
index f597cc6f23..4becbb221f 100644
--- a/frontend/README.md
+++ b/frontend/README.md
@@ -5,7 +5,7 @@
 - `python scripts/dream.py --web` serves both frontend and backend at
   http://localhost:9090

-## Evironment
+## Environment

 Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
 [yarn](https://yarnpkg.com/getting-started/install).
@@ -15,7 +15,7 @@ packages.

 ## Dev

-1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
+1. From `frontend/`, run `npm run dev` / `yarn dev` to start the dev server.
 2. Run `python scripts/dream.py --web`.
 3. Navigate to the dev server address e.g. `http://localhost:5173/`.
diff --git a/ldm/generate.py b/ldm/generate.py
index 28250695b5..920b9b9301 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -805,6 +805,10 @@ class Generate:
         # the model cache does the loading and offloading
         cache = self.model_cache
+        if not cache.valid_model(model_name):
+            print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
+            return self.model
+
         cache.print_vram_usage()

         # have to get rid of all references to model in order
@@ -1032,7 +1036,9 @@ class Generate:
                 return True
         return False

-    def _check_for_erasure(self, image):
+    def _check_for_erasure(self, image:Image.Image)->bool:
+        if image.mode not in ('RGBA','RGB'):
+            return False
         width, height = image.size
         pixdata = image.load()
         colored = 0
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 0392418d5d..5a2d7ae97c 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -247,8 +247,6 @@ class Args(object):
             switches.append('--seamless')
         if a['hires_fix']:
             switches.append('--hires_fix')
-        if a['strength'] and a['strength']>0:
-            switches.append(f'-f {a["strength"]}')

         # img2img generations have parameters relevant only to them and have special handling
         if a['init_img'] and len(a['init_img'])>0:
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index d95ad78196..759ba2dba4 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -10,8 +10,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.invoke.generator.omnibus import Omnibus
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from PIL import Image
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.image_util import InitImageResizer

 class Txt2Img2Img(Generator):
     def __init__(self, model, precision):
@@ -46,13 +44,16 @@ class Txt2Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )

+        #x = self.get_noise(init_width, init_height)
+        x = x_T
+
         if self.free_gpu_mem and self.model.model.device != self.model.device:
             self.model.model.to(self.model.device)

         samples, _ = sampler.sample(
             batch_size = 1,
             S = steps,
-            x_T = x_T,
+            x_T = x,
             conditioning = c,
             shape = shape,
             verbose = False,
@@ -68,21 +69,11 @@ class Txt2Img2Img(Generator):
         )

         # resizing
-
-        image = self.sample_to_image(samples)
-        image = InitImageResizer(image).resize(width, height)
-
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        image = image.to(self.model.device)
-
-        scope = choose_autocast(self.precision)
-        with scope(self.model.device.type):
-            samples = self.model.get_first_stage_encoding(
-                self.model.encode_first_stage(image)
-            ) # move back to latent space
+        samples = torch.nn.functional.interpolate(
+            samples,
+            size=(height // self.downsampling_factor, width // self.downsampling_factor),
+            mode="bilinear"
+        )

         t_enc = int(strength * steps)
         ddim_sampler = DDIMSampler(self.model, device=self.model.device)
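NOTE: the txt2img2img change above replaces a decode-to-pixels, resize, re-encode round trip with a single bilinear interpolation performed directly in latent space. Illustrative only, using the standard downsampling factor of 8:

    import torch

    samples = torch.randn(1, 4, 64, 64)   # latent for a 512x512 image (512 / 8 = 64)
    upscaled = torch.nn.functional.interpolate(
        samples, size=(96, 96), mode='bilinear'
    )                                     # 96 * 8 = 768, so the second pass yields 768x768
    print(upscaled.shape)                 # torch.Size([1, 4, 96, 96])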
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 7b434941df..d4007c46de 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -41,15 +41,22 @@ class ModelCache(object):
         self.stack = []  # this is an LRU FIFO
         self.current_model = None

+    def valid_model(self, model_name:str)->bool:
+        '''
+        Given a model name, returns True if it is a valid
+        identifier.
+        '''
+        return model_name in self.config
+
     def get_model(self, model_name:str):
         '''
         Given a model named identified in models.yaml, return
         the model object. If in RAM will load into GPU VRAM.
         If on disk, will load from there.
         '''
-        if model_name not in self.config:
+        if not self.valid_model(model_name):
             print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
-            return None
+            return self.current_model

         if self.current_model != model_name:
             if model_name not in self.models: # make room for a new one
@@ -102,10 +109,13 @@ class ModelCache(object):
         Set the default model. The change will not take effect until you call
         model_cache.commit()
         '''
+        print(f'DEBUG: before set_default_model()\n{OmegaConf.to_yaml(self.config)}')
         assert model_name in self.models,f"unknown model '{model_name}'"
-        for model in self.models:
-            self.models[model].pop('default',None)
-        self.models[model_name]['default'] = True
+        config = self.config
+        for model in config:
+            config[model].pop('default',None)
+        config[model_name]['default'] = True
+        print(f'DEBUG: after set_default_model():\n{OmegaConf.to_yaml(self.config)}')

     def list_models(self) -> dict:
         '''
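NOTE: the set_default_model() fix above moves the `default` flag from self.models (the runtime cache) to self.config (the OmegaConf tree that commit() later writes back to models.yaml). The same loop on a toy config - a sketch, assuming omegaconf is installed:

    from omegaconf import OmegaConf

    config = OmegaConf.create({'sd-1.4': {'default': True}, 'sd-1.5': {}})
    for model in config:
        config[model].pop('default', None)   # clear any previous default
    config['sd-1.5']['default'] = True
    print(OmegaConf.to_yaml(config))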
diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py
index 7d87ede755..4e95e9b063 100644
--- a/ldm/invoke/readline.py
+++ b/ldm/invoke/readline.py
@@ -284,6 +284,7 @@ class Completer(object):
         switch,partial_path  = match.groups()

         partial_path = partial_path.lstrip()
+
         matches = list()
         path = os.path.expanduser(partial_path)
@@ -321,6 +322,7 @@ class Completer(object):
                 matches.append(
                     switch+os.path.join(os.path.dirname(full_path), node)
                 )
+
         return matches

 class DummyCompleter(Completer):
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 9c8c597869..ff90a24856 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -1,10 +1,13 @@
-from enum import Enum
+import enum
+from typing import Optional

 import torch

 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
+
+
 class CrossAttentionControl:

     class Arguments:
@@ -27,7 +30,14 @@ class CrossAttentionControl:
                 print('warning: cross-attention control options are not working properly for >1 edit')
             self.edit_options = non_none_edit_options[0]

+
     class Context:
+
+        class Action(enum.Enum):
+            NONE = 0
+            SAVE = 1
+            APPLY = 2
+
         def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
             """
             :param arguments: Arguments for the cross-attention control process
@@ -36,14 +46,124 @@ class CrossAttentionControl:
             self.arguments = arguments
             self.step_count = step_count

+            self.self_cross_attention_module_identifiers = []
+            self.tokens_cross_attention_module_identifiers = []
+
+            self.saved_cross_attention_maps = {}
+
+            self.clear_requests(cleanup=True)
+
+        def register_cross_attention_modules(self, model):
+            for name,module in CrossAttentionControl.get_attention_modules(model,
+                                                                           CrossAttentionControl.CrossAttentionType.SELF):
+                self.self_cross_attention_module_identifiers.append(name)
+            for name,module in CrossAttentionControl.get_attention_modules(model,
+                                                                           CrossAttentionControl.CrossAttentionType.TOKENS):
+                self.tokens_cross_attention_module_identifiers.append(name)
+
+        def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
+            if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
+                self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
+            else:
+                self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
+
+        def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
+            if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
+                self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
+            else:
+                self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
+
+        def is_tokens_cross_attention(self, module_identifier) -> bool:
+            return module_identifier in self.tokens_cross_attention_module_identifiers
+
+        def get_should_save_maps(self, module_identifier: str) -> bool:
+            if module_identifier in self.self_cross_attention_module_identifiers:
+                return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
+            elif module_identifier in self.tokens_cross_attention_module_identifiers:
+                return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
+            return False
+
+        def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
+            if module_identifier in self.self_cross_attention_module_identifiers:
+                return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
+            elif module_identifier in self.tokens_cross_attention_module_identifiers:
+                return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
+            return False
+
+        def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
+                -> list['CrossAttentionControl.CrossAttentionType']:
+            """
+            Should cross-attention control be applied on the given step?
+            :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
+            :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
+            """
+            if percent_through is None:
+                return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
+
+            opts = self.arguments.edit_options
+            to_control = []
+            if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
+                to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
+            if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
+                to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
+            return to_control
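NOTE: get_active_cross_attention_control_types_for_step() gates each attention type on a [start, end) window over sampling progress. A worked example with illustrative edit_options values:

    opts = {'s_start': 0.0, 's_end': 0.3, 't_start': 0.1, 't_end': 1.0}
    for pct in (0.05, 0.2, 0.5):
        active = []
        if opts['s_start'] <= pct < opts['s_end']:
            active.append('SELF')
        if opts['t_start'] <= pct < opts['t_end']:
            active.append('TOKENS')
        print(pct, active)   # 0.05 ['SELF'] / 0.2 ['SELF', 'TOKENS'] / 0.5 ['TOKENS']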
+ """ + if percent_through is None: + return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS] + + opts = self.arguments.edit_options + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(CrossAttentionControl.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS) + return to_control + + def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, + slice_size: Optional[int]): + if identifier not in self.saved_cross_attention_maps: + self.saved_cross_attention_maps[identifier] = { + 'dim': dim, + 'slice_size': slice_size, + 'slices': {offset or 0: slice} + } + else: + self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice + + def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): + saved_attention_dict = self.saved_cross_attention_maps[identifier] + if requested_dim is None: + if saved_attention_dict['dim'] is not None: + raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") + return saved_attention_dict['slices'][0] + + if saved_attention_dict['dim'] == requested_dim: + if slice_size != saved_attention_dict['slice_size']: + raise RuntimeError( + f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") + return saved_attention_dict['slices'][requested_offset] + + if saved_attention_dict['dim'] == None: + whole_saved_attention = saved_attention_dict['slices'][0] + if requested_dim == 0: + return whole_saved_attention[requested_offset:requested_offset + slice_size] + elif requested_dim == 1: + return whole_saved_attention[:, requested_offset:requested_offset + slice_size] + + raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") + + def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]: + saved_attention = self.saved_cross_attention_maps.get(identifier, None) + if saved_attention is None: + return None, None + return saved_attention['dim'], saved_attention['slice_size'] + + def clear_requests(self, cleanup=True): + self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE + self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE + if cleanup: + self.saved_cross_attention_maps = {} + + def offload_saved_attention_slices_to_cpu(self): + for key, map_dict in self.saved_cross_attention_maps.items(): + for offset, slice in map_dict['slices'].items(): + map_dict[offset] = slice.to('cpu') + @classmethod def remove_cross_attention_control(cls, model): cls.remove_attention_function(model) @classmethod - def setup_cross_attention_control(cls, model, - cross_attention_control_args: Arguments - ): + def setup_cross_attention_control(cls, model, context: Context): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -53,7 +173,7 @@ class CrossAttentionControl: """ # adapted from init_attention_edit - device = cross_attention_control_args.edited_conditioning.device + device = context.arguments.edited_conditioning.device # urgh. should this be hardcoded? 
     @classmethod
     def remove_cross_attention_control(cls, model):
         cls.remove_attention_function(model)

     @classmethod
-    def setup_cross_attention_control(cls, model,
-                                      cross_attention_control_args: Arguments
-                                      ):
+    def setup_cross_attention_control(cls, model, context: Context):
         """
         Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -53,7 +173,7 @@ class CrossAttentionControl:
         """

         # adapted from init_attention_edit
-        device = cross_attention_control_args.edited_conditioning.device
+        device = context.arguments.edited_conditioning.device

         # urgh. should this be hardcoded?
         max_length = 77
@@ -61,141 +181,82 @@
         mask = torch.zeros(max_length)
         indices_target = torch.arange(max_length, dtype=torch.long)
         indices = torch.zeros(max_length, dtype=torch.long)
-        for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
+        for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
             if b0 < max_length:
                 if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                     # these tokens have not been edited
                     indices[b0:b1] = indices_target[a0:a1]
                     mask[b0:b1] = 1

-        cls.inject_attention_function(model)
-
-        for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
-            m.last_attn_slice_mask = None
-            m.last_attn_slice_indices = None
-
-        for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
-            m.last_attn_slice_mask = mask.to(device)
-            m.last_attn_slice_indices = indices.to(device)
+        context.register_cross_attention_modules(model)
+        context.cross_attention_mask = mask.to(device)
+        context.cross_attention_index_map = indices.to(device)
+        cls.inject_attention_function(model, context)

-    class CrossAttentionType(Enum):
+    class CrossAttentionType(enum.Enum):
         SELF = 1
         TOKENS = 2

-    @classmethod
-    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-            -> list['CrossAttentionControl.CrossAttentionType']:
-        """
-        Should cross-attention control be applied on the given step?
-        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
-        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
-        """
-        if percent_through is None:
-            return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
-
-        opts = context.arguments.edit_options
-        to_control = []
-        if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
-            to_control.append(cls.CrossAttentionType.SELF)
-        if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
-            to_control.append(cls.CrossAttentionType.TOKENS)
-        return to_control
-

     @classmethod
     def get_attention_modules(cls, model, which: CrossAttentionType):
         which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
-        return [module for name, module in model.named_modules() if
+        return [(name,module) for name, module in model.named_modules() if
                 type(module).__name__ == "CrossAttention" and which_attn in name]

-    @classmethod
-    def clear_requests(cls, model, clear_attn_slice=True):
-        self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
-        tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
-        for m in self_attention_modules+tokens_attention_modules:
-            m.save_last_attn_slice = False
-            m.use_last_attn_slice = False
-            if clear_attn_slice:
-                m.last_attn_slice = None

     @classmethod
-    def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
-        modules = cls.get_attention_modules(model, cross_attention_type)
-        for m in modules:
-            # clear out the saved slice in case the outermost dim changes
-            m.last_attn_slice = None
-            m.save_last_attn_slice = True
-
-    @classmethod
-    def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
-        modules = cls.get_attention_modules(model, cross_attention_type)
-        for m in modules:
-            m.use_last_attn_slice = True
-
-
-
-    @classmethod
-    def inject_attention_function(cls, unet):
+    def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

-        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
+        def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):

-            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
+            #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

-            attn_slice = suggested_attention_slice
-            if dim is not None:
-                start = offset
-                end = start+slice_size
-                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
-            #else:
-            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
+            attention_slice = suggested_attention_slice

-            if self.use_last_attn_slice:
-                if dim is None:
-                    last_attn_slice = self.last_attn_slice
-                    # print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
+            if context.get_should_save_maps(module.identifier):
+                #print(module.identifier, "saving suggested_attention_slice of shape",
+                #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
+                slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
+                context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
+            elif context.get_should_apply_saved_maps(module.identifier):
+                #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
+                saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
+
+                # slice may have been offloaded to CPU
+                saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
+
+                if context.is_tokens_cross_attention(module.identifier):
+                    index_map = context.cross_attention_index_map
+                    remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
+                    this_attention_slice = suggested_attention_slice
+
+                    mask = context.cross_attention_mask
+                    saved_mask = mask
+                    this_mask = 1 - mask
+                    attention_slice = remapped_saved_attention_slice * saved_mask + \
+                                      this_attention_slice * this_mask
                 else:
-                    last_attn_slice = self.last_attn_slice[offset]
-
-                if self.last_attn_slice_mask is None: # just use everything
-                    attn_slice = last_attn_slice
-                else:
-                    last_attn_slice_mask = self.last_attn_slice_mask
-                    remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
+                    attention_slice = saved_attention_slice

-                    this_attn_slice = attn_slice
-                    this_attn_slice_mask = 1 - last_attn_slice_mask
-                    attn_slice = this_attn_slice * this_attn_slice_mask + \
-                                 remapped_last_attn_slice * last_attn_slice_mask
-
-            if self.save_last_attn_slice:
-                if dim is None:
-                    self.last_attn_slice = attn_slice
-                else:
-                    if self.last_attn_slice is None:
-                        self.last_attn_slice = { offset: attn_slice }
-                    else:
-                        self.last_attn_slice[offset] = attn_slice
-
-            return attn_slice
+            return attention_slice
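NOTE: the token branch above blends the saved attention (remapped through cross_attention_index_map) with the freshly computed attention, using cross_attention_mask: 1 where a token survived the edit, 0 where it changed. The arithmetic in isolation, with illustrative shapes:

    import torch

    saved = torch.randn(8, 64, 6)                  # saved token attention, 6 positions
    fresh = torch.randn(8, 64, 6)                  # attention for the edited prompt
    index_map = torch.tensor([0, 1, 2, 3, 4, 5])   # identity remap, for the example
    mask = torch.tensor([1., 1., 1., 0., 0., 1.])  # positions 3 and 4 were edited

    remapped = torch.index_select(saved, -1, index_map)
    blended = remapped * mask + fresh * (1 - mask)  # saved where 1, fresh where 0
    print(blended.shape)                            # torch.Size([8, 64, 6])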

         for name, module in unet.named_modules():
             module_name = type(module).__name__
             if module_name == "CrossAttention":
-                module.last_attn_slice = None
-                module.last_attn_slice_indices = None
-                module.last_attn_slice_mask = None
-                module.use_last_attn_weights = False
-                module.use_last_attn_slice = False
-                module.save_last_attn_slice = False
+                module.identifier = name
                 module.set_attention_slice_wrangler(attention_slice_wrangler)
+                module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
+                                                       context.get_slicing_strategy(module_identifier))

     @classmethod
     def remove_attention_function(cls, unet):
+        # clear wrangler callback
         for name, module in unet.named_modules():
             module_name = type(module).__name__
             if module_name == "CrossAttention":
                 module.set_attention_slice_wrangler(None)
+                module.set_slicing_strategy_getter(None)
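NOTE: the `module_identifier=name` default argument in the lambda above is deliberate: closures created in a loop capture variables by reference, so without it every getter would see the last value of `name`. A minimal demonstration:

    fns = [lambda: name for name in ('attn1', 'attn2')]
    print([f() for f in fns])    # ['attn2', 'attn2'] - late binding

    fns = [lambda name=name: name for name in ('attn1', 'attn2')]
    print([f() for f in fns])    # ['attn1', 'attn2'] - bound at definition time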
""" - CrossAttentionControl.clear_requests(self.model) cross_attention_control_types_to_do = [] + context: CrossAttentionControl.Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: percent_through = self.estimate_percent_through(step_index, sigma) - cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_hybrid_conditioning = isinstance(conditioning, dict) @@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of @@ -134,32 +133,32 @@ class InvokeAIDiffuserComponent: # representing batched uncond + cond, but then when it comes to applying the saved attention, the # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. + context:CrossAttentionControl.Context = self.cross_attention_control_context try: unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) # process x using the original prompt, saving the attention maps - for type in cross_attention_control_types_to_do: - CrossAttentionControl.request_save_attention_maps(self.model, type) + #print("saving attention maps for", cross_attention_control_types_to_do) + for ca_type in cross_attention_control_types_to_do: + context.request_save_attention_maps(ca_type) _ = self.model_forward_callback(x, sigma, conditioning) - CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False) + context.clear_requests(cleanup=False) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - for type in cross_attention_control_types_to_do: - CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + #print("applying saved attention maps for", cross_attention_control_types_to_do) + for ca_type in cross_attention_control_types_to_do: + context.request_apply_saved_attention_maps(ca_type) edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) + context.clear_requests(cleanup=True) - CrossAttentionControl.clear_requests(self.model) - - return unconditioned_next_x, conditioned_next_x - - except RuntimeError: - # make sure we clean out the attention slices we're storing on the model - # TODO don't store things on the model - CrossAttentionControl.clear_requests(self.model) + except: + context.clear_requests(cleanup=True) raise + return unconditioned_next_x, conditioned_next_x + def estimate_percent_through(self, step_index, sigma): if step_index is not None 
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 4c36fa8a6c..94bb8a2916 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,6 +1,6 @@
 from inspect import isfunction
 import math
-from typing import Callable
+from typing import Callable, Optional

 import torch
 import torch.nn.functional as F
@@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module):
         return x+h_

+def get_mem_free_total(device):
+    #only on cuda
+    if not torch.cuda.is_available():
+        return None
+    stats = torch.cuda.memory_stats(device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+    return mem_free_total

 class CrossAttention(nn.Module):
@@ -173,31 +184,43 @@ class CrossAttention(nn.Module):

         self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)

+        self.cached_mem_free_total = None
         self.attention_slice_wrangler = None
+        self.slicing_strategy_getter = None

-    def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
+    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
         '''
         Set custom attention calculator to be called when attention is calculated
-        :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
+        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            self is the current CrossAttention module for which the callback is being invoked.
-            attention_scores are the scores for attention
-            suggested_attention_slice is a softmax(dim=-1) over attention_scores
-            dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
-            If dim is >= 0, offset and slice_size specify the slice start and length.
+            `module` is the current CrossAttention module for which the callback is being invoked.
+            `suggested_attention_slice` is the default-calculated attention slice
+            `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
+            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.

         Pass None to use the default attention calculation.
         :return:
         '''
         self.attention_slice_wrangler = wrangler

+    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
+        self.slicing_strategy_getter = getter
+
+    def cache_free_memory_count(self, device):
+        self.cached_mem_free_total = get_mem_free_total(device)
+        print("free cuda memory: ", self.cached_mem_free_total)
+
+    def clear_cached_free_memory_count(self):
+        self.cached_mem_free_total = None
+
     def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
         # calculate attention scores
         attention_scores = einsum('b i d, b j d -> b i j', q, k)
-        # calculate attenion slice by taking the best scores for each latent pixel
+        # calculate attention slice by taking the best scores for each latent pixel
         default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        if self.attention_slice_wrangler is not None:
-            attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
+        attention_slice_wrangler = self.attention_slice_wrangler
+        if attention_slice_wrangler is not None:
+            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
         else:
             attention_slice = default_attention_slice

@@ -240,17 +263,26 @@ class CrossAttention(nn.Module):
         return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))

     def einsum_op_cuda(self, q, k, v):
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
+        slicing_strategy_getter = self.slicing_strategy_getter
+        if slicing_strategy_getter is not None:
+            (dim, slice_size) = slicing_strategy_getter(self)
+            if dim is not None:
+                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
+                if dim == 0:
+                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
+                elif dim == 1:
+                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+        # fallback for when there is no saved strategy, or saved strategy does not slice
+        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
         # Divide factor of safety as there's copying and fragmentation
         return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

+
     def get_attention_mem_efficient(self, q, k, v):
         if q.device.type == 'cuda':
+            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
             return self.einsum_op_cuda(q, k, v)

         if q.device.type == 'mps':
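NOTE: einsum_op_cuda() now prefers a previously recorded slicing strategy and only falls back to the free-memory heuristic. That heuristic converts bytes to megabytes with a roughly 3.3x safety factor for copies and fragmentation - worked numbers:

    mem_free_total = 8 * (1 << 30)                # suppose 8 GiB effectively free
    budget_mb = mem_free_total / 3.3 / (1 << 20)  # MB the einsum is allowed to use
    print(round(budget_mb))                       # 2482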
diff --git a/requirements.txt b/requirements.txt
index 939463e36e..fce5c87abf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -38,4 +38,4 @@ git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
 git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
--e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
diff --git a/scripts/invoke.py b/scripts/invoke.py
index 5a474a5463..0417eacde6 100755
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -90,7 +90,12 @@ def main():
             safety_checker=opt.safety_checker,
             max_loaded_models=opt.max_loaded_models,
         )
-    except (FileNotFoundError, IOError, KeyError) as e:
+    except FileNotFoundError:
+        print('** You appear to be missing configs/models.yaml')
+        print('** You can either exit this script and run scripts/preload_models.py, or fix the problem now.')
+        emergency_model_create(opt)
+        sys.exit(-1)
+    except (IOError, KeyError) as e:
         print(f'{e}. Aborting.')
         sys.exit(-1)

@@ -485,6 +490,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         command = '-h'
     return command, operation

+
 def add_weights_to_config(model_path:str, gen, opt, completer):
     print(f'>> Model import in process. Please enter the values needed to configure this model:')
     print()
@@ -581,7 +587,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak

     try:
         print('>> Verifying that new model loads...')
-        yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
+        gen.model_cache.add_model(model_name, new_config, clobber)
         assert gen.set_model(model_name) is not None, 'model failed to load'
     except AssertionError as e:
         print(f'** aborting **')
@@ -894,6 +900,36 @@ def write_commands(opt, file_path:str, outfilepath:str):
             f.write('\n'.join(commands))
         print(f'>> File {outfilepath} with commands created')

+def emergency_model_create(opt:Args):
+    completer = get_completer(opt)
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    completer.set_default_dir('.')
+    valid_path = False
+    while not valid_path:
+        weights_file = input('Enter the path to a downloaded models file, or ^C to exit: ')
+        valid_path = os.path.exists(weights_file)
+    dir,basename = os.path.split(weights_file)
+
+    valid_name = False
+    while not valid_name:
+        name = input('Enter a short name for this model (no spaces): ')
+        name = 'unnamed-model' if len(name)==0 else name
+        valid_name = ' ' not in name
+
+    description = input('Enter a description for this model: ')
+    description = 'no description' if len(description)==0 else description
+
+    with open(opt.conf, 'w', encoding='utf-8') as f:
+        f.write(f'{name}:\n')
+        f.write(f'  description: {description}\n')
+        f.write(f'  weights: {weights_file}\n')
+        f.write(f'  config: ./configs/stable-diffusion/v1-inference.yaml\n')
+        f.write(f'  width: 512\n')
+        f.write(f'  height: 512\n')
+        f.write(f'  default: true\n')
+    print(f'Config file {opt.conf} is created. This script will now exit.')
+    print(f'After restarting you may examine the entry with !models and edit it with !edit.')
+
 ######################################

 if __name__ == '__main__':
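NOTE: emergency_model_create() writes a one-entry models.yaml by hand. The result is ordinary YAML, so the entry can be checked the same way ModelCache reads it - a sketch, with an illustrative path standing in for opt.conf:

    from omegaconf import OmegaConf

    conf = OmegaConf.load('configs/models.yaml')   # illustrative path
    default = next(m for m in conf if conf[m].get('default'))
    print(default, '->', conf[default]['weights'])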