mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into lstein-outcrop-improvements
This commit is contained in:
commit
9c218788e2
1
.github/workflows/test-invoke-conda.yml
vendored
1
.github/workflows/test-invoke-conda.yml
vendored
@ -4,7 +4,6 @@ on:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
- 'fix-gh-actions-fork'
|
||||
pull_request:
|
||||
branches:
|
||||
- 'main'
|
||||
|
@ -2,15 +2,16 @@ name: invokeai
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.*
|
||||
- python=3.10
|
||||
- pip>=22.2.2
|
||||
- cudatoolkit
|
||||
- pytorch
|
||||
- torchvision
|
||||
- numpy=1.19
|
||||
- imageio=2.9.0
|
||||
- opencv=4.6.0
|
||||
- numpy=1.23
|
||||
- imageio=2.21
|
||||
- opencv=4.6
|
||||
- pillow=8.*
|
||||
- flask=2.1.*
|
||||
- flask_cors=3.0.10
|
||||
@ -25,21 +26,18 @@ dependencies:
|
||||
- einops=0.3.0
|
||||
- kornia=0.6
|
||||
- torchmetrics=0.7.0
|
||||
- transformers=4.21.3
|
||||
- transformers=4.23
|
||||
- torch-fidelity=0.3.0
|
||||
- tokenizers>=0.11.1,!=0.11.3,<0.13
|
||||
- pip:
|
||||
- getpass_asterisk
|
||||
- omegaconf==2.1.1
|
||||
- realesrgan==0.2.5.0
|
||||
- test-tube>=0.7.5
|
||||
- pyreadline3
|
||||
- dependency_injector==4.40.0
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||
- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- taming-transformers-rom1504
|
||||
- test-tube>=0.7.5
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
|
||||
- git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
|
||||
- git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
|
||||
- git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- -e .
|
||||
variables:
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 1
|
||||
|
@ -2,12 +2,15 @@ name: invokeai
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.13
|
||||
- pip=22.2.2
|
||||
|
||||
- pytorch=1.12.1
|
||||
- torchvision=0.13.1
|
||||
- python=3.10
|
||||
- pip>=22.2
|
||||
- pytorch=1.12
|
||||
- pytorch-lightning=1.7
|
||||
- torchvision=0.13
|
||||
- torchmetrics=0.10
|
||||
- torch-fidelity=0.3
|
||||
|
||||
# I suggest to keep the other deps sorted for convenience.
|
||||
# To determine what the latest versions should be, run:
|
||||
@ -17,49 +20,45 @@ dependencies:
|
||||
# CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac-updated.yml && conda list -n invokeai-updated | awk ' {print " - " $1 "==" $2;} '
|
||||
# ```
|
||||
|
||||
- albumentations=1.2.1
|
||||
- coloredlogs=15.0.1
|
||||
- diffusers=0.6.0
|
||||
- einops=0.4.1
|
||||
- grpcio=1.46.4
|
||||
- albumentations=1.2
|
||||
- coloredlogs=15.0
|
||||
- diffusers=0.6
|
||||
- einops=0.3
|
||||
- eventlet
|
||||
- grpcio=1.46
|
||||
- flask=2.1
|
||||
- flask-socketio=5.3
|
||||
- flask-cors=3.0
|
||||
- humanfriendly=10.0
|
||||
- imageio=2.21.2
|
||||
- imageio-ffmpeg=0.4.7
|
||||
- imgaug=0.4.0
|
||||
- kornia=0.6.7
|
||||
- mpmath=1.2.1
|
||||
- nomkl # arm64 has only 1.0 while x64 needs 3.0
|
||||
- numpy=1.23.4
|
||||
- omegaconf=2.1.1
|
||||
- openh264=2.3.0
|
||||
- onnx=1.12.0
|
||||
- onnxruntime=1.12.1
|
||||
- pudb=2022.1
|
||||
- pytorch-lightning=1.7.7
|
||||
- scipy=1.9.3
|
||||
- streamlit=1.12.2
|
||||
- sympy=1.10.1
|
||||
- tensorboard=2.10.0
|
||||
- torchmetrics=0.10.1
|
||||
- py-opencv=4.6.0
|
||||
- flask=2.1.3
|
||||
- flask-socketio=5.3.0
|
||||
- flask-cors=3.0.10
|
||||
- eventlet=0.33.1
|
||||
- protobuf=3.20.1
|
||||
- send2trash=1.8.0
|
||||
- transformers=4.23.1
|
||||
- torch-fidelity=0.3.0
|
||||
- imageio=2.21
|
||||
- imageio-ffmpeg=0.4
|
||||
- imgaug=0.4
|
||||
- kornia=0.6
|
||||
- mpmath=1.2
|
||||
- nomkl=3
|
||||
- numpy=1.23
|
||||
- omegaconf=2.1
|
||||
- openh264=2.3
|
||||
- onnx=1.12
|
||||
- onnxruntime=1.12
|
||||
- pudb=2019.2
|
||||
- protobuf=3.20
|
||||
- py-opencv=4.6
|
||||
- scipy=1.9
|
||||
- streamlit=1.12
|
||||
- sympy=1.10
|
||||
- send2trash=1.8
|
||||
- tensorboard=2.10
|
||||
- transformers=4.23
|
||||
- pip:
|
||||
- getpass_asterisk
|
||||
- dependency_injector==4.40.0
|
||||
- realesrgan==0.2.5.0
|
||||
- taming-transformers-rom1504
|
||||
- test-tube==0.7.5
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||
- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
|
||||
- git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
|
||||
- git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
|
||||
- git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- -e .
|
||||
variables:
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 1
|
||||
|
@ -4,7 +4,7 @@ channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python>=3.9
|
||||
- python=3.10
|
||||
- pip=22.2.2
|
||||
- numpy=1.23.3
|
||||
- torchvision=0.13.1
|
||||
@ -32,14 +32,13 @@ dependencies:
|
||||
- flask==2.1.3
|
||||
- flask_socketio==5.3.0
|
||||
- flask_cors==3.0.10
|
||||
- dependency_injector==4.40.0
|
||||
- eventlet
|
||||
- getpass_asterisk
|
||||
- kornia==0.6.0
|
||||
- taming-transformers-rom1504
|
||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- -e git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
|
||||
- -e git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
|
||||
- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- git+https://github.com/invoke-ai/k-diffusion.git@mps#egg=k_diffusion
|
||||
- git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
|
||||
- git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
|
||||
- git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
- -e .
|
||||
|
@ -5,7 +5,7 @@
|
||||
- `python scripts/dream.py --web` serves both frontend and backend at
|
||||
http://localhost:9090
|
||||
|
||||
## Evironment
|
||||
## Environment
|
||||
|
||||
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
|
||||
[yarn](https://yarnpkg.com/getting-started/install).
|
||||
@ -15,7 +15,7 @@ packages.
|
||||
|
||||
## Dev
|
||||
|
||||
1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
|
||||
1. From `frontend/`, run `npm run dev` / `yarn dev` to start the dev server.
|
||||
2. Run `python scripts/dream.py --web`.
|
||||
3. Navigate to the dev server address e.g. `http://localhost:5173/`.
|
||||
|
||||
|
@ -805,6 +805,10 @@ class Generate:
|
||||
|
||||
# the model cache does the loading and offloading
|
||||
cache = self.model_cache
|
||||
if not cache.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.model
|
||||
|
||||
cache.print_vram_usage()
|
||||
|
||||
# have to get rid of all references to model in order
|
||||
@ -1032,7 +1036,9 @@ class Generate:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_for_erasure(self, image):
|
||||
def _check_for_erasure(self, image:Image.Image)->bool:
|
||||
if image.mode not in ('RGBA','RGB'):
|
||||
return False
|
||||
width, height = image.size
|
||||
pixdata = image.load()
|
||||
colored = 0
|
||||
|
@ -247,8 +247,6 @@ class Args(object):
|
||||
switches.append('--seamless')
|
||||
if a['hires_fix']:
|
||||
switches.append('--hires_fix')
|
||||
if a['strength'] and a['strength']>0:
|
||||
switches.append(f'-f {a["strength"]}')
|
||||
|
||||
# img2img generations have parameters relevant only to them and have special handling
|
||||
if a['init_img'] and len(a['init_img'])>0:
|
||||
|
@ -10,8 +10,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -46,13 +44,16 @@ class Txt2Img2Img(Generator):
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
#x = self.get_noise(init_width, init_height)
|
||||
x = x_T
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
x_T = x,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
@ -68,21 +69,11 @@ class Txt2Img2Img(Generator):
|
||||
)
|
||||
|
||||
# resizing
|
||||
|
||||
image = self.sample_to_image(samples)
|
||||
image = InitImageResizer(image).resize(width, height)
|
||||
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2.0 * image - 1.0
|
||||
image = image.to(self.model.device)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
samples = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(image)
|
||||
) # move back to latent space
|
||||
samples = torch.nn.functional.interpolate(
|
||||
samples,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
|
@ -41,15 +41,22 @@ class ModelCache(object):
|
||||
self.stack = [] # this is an LRU FIFO
|
||||
self.current_model = None
|
||||
|
||||
def valid_model(self, model_name:str)->bool:
|
||||
'''
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
'''
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name:str):
|
||||
'''
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
'''
|
||||
if model_name not in self.config:
|
||||
if not self.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return None
|
||||
return self.current_model
|
||||
|
||||
if self.current_model != model_name:
|
||||
if model_name not in self.models: # make room for a new one
|
||||
@ -102,10 +109,13 @@ class ModelCache(object):
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_cache.commit()
|
||||
'''
|
||||
print(f'DEBUG: before set_default_model()\n{OmegaConf.to_yaml(self.config)}')
|
||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||
for model in self.models:
|
||||
self.models[model].pop('default',None)
|
||||
self.models[model_name]['default'] = True
|
||||
config = self.config
|
||||
for model in config:
|
||||
config[model].pop('default',None)
|
||||
config[model_name]['default'] = True
|
||||
print(f'DEBUG: after set_default_model():\n{OmegaConf.to_yaml(self.config)}')
|
||||
|
||||
def list_models(self) -> dict:
|
||||
'''
|
||||
|
@ -284,6 +284,7 @@ class Completer(object):
|
||||
switch,partial_path = match.groups()
|
||||
partial_path = partial_path.lstrip()
|
||||
|
||||
|
||||
matches = list()
|
||||
path = os.path.expanduser(partial_path)
|
||||
|
||||
@ -321,6 +322,7 @@ class Completer(object):
|
||||
matches.append(
|
||||
switch+os.path.join(os.path.dirname(full_path), node)
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
class DummyCompleter(Completer):
|
||||
|
@ -1,10 +1,13 @@
|
||||
from enum import Enum
|
||||
import enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
|
||||
class CrossAttentionControl:
|
||||
|
||||
class Arguments:
|
||||
@ -27,7 +30,14 @@ class CrossAttentionControl:
|
||||
print('warning: cross-attention control options are not working properly for >1 edit')
|
||||
self.edit_options = non_none_edit_options[0]
|
||||
|
||||
|
||||
class Context:
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = 1,
|
||||
APPLY = 2
|
||||
|
||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
@ -36,14 +46,124 @@ class CrossAttentionControl:
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
self.self_cross_attention_module_identifiers = []
|
||||
self.tokens_cross_attention_module_identifiers = []
|
||||
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name,module in CrossAttentionControl.get_attention_modules(model,
|
||||
CrossAttentionControl.CrossAttentionType.SELF):
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name,module in CrossAttentionControl.get_attention_modules(model,
|
||||
CrossAttentionControl.CrossAttentionType.TOKENS):
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
||||
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
||||
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
||||
else:
|
||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
||||
|
||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||
|
||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
||||
return False
|
||||
|
||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = self.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
||||
slice_size: Optional[int]):
|
||||
if identifier not in self.saved_cross_attention_maps:
|
||||
self.saved_cross_attention_maps[identifier] = {
|
||||
'dim': dim,
|
||||
'slice_size': slice_size,
|
||||
'slices': {offset or 0: slice}
|
||||
}
|
||||
else:
|
||||
self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
|
||||
|
||||
def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict['dim'] is not None:
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict['slices'][0]
|
||||
|
||||
if saved_attention_dict['dim'] == requested_dim:
|
||||
if slice_size != saved_attention_dict['slice_size']:
|
||||
raise RuntimeError(
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||
return saved_attention_dict['slices'][requested_offset]
|
||||
|
||||
if saved_attention_dict['dim'] == None:
|
||||
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
return saved_attention['dim'], saved_attention['slice_size']
|
||||
|
||||
def clear_requests(self, cleanup=True):
|
||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
||||
if cleanup:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict['slices'].items():
|
||||
map_dict[offset] = slice.to('cpu')
|
||||
|
||||
@classmethod
|
||||
def remove_cross_attention_control(cls, model):
|
||||
cls.remove_attention_function(model)
|
||||
|
||||
@classmethod
|
||||
def setup_cross_attention_control(cls, model,
|
||||
cross_attention_control_args: Arguments
|
||||
):
|
||||
def setup_cross_attention_control(cls, model, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@ -53,7 +173,7 @@ class CrossAttentionControl:
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = cross_attention_control_args.edited_conditioning.device
|
||||
device = context.arguments.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
@ -61,141 +181,82 @@ class CrossAttentionControl:
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.zeros(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
cls.inject_attention_function(model)
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
||||
m.last_attn_slice_mask = None
|
||||
m.last_attn_slice_indices = None
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
|
||||
m.last_attn_slice_mask = mask.to(device)
|
||||
m.last_attn_slice_indices = indices.to(device)
|
||||
context.register_cross_attention_modules(model)
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
cls.inject_attention_function(model, context)
|
||||
|
||||
|
||||
class CrossAttentionType(Enum):
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
@classmethod
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = context.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(cls.CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
to_control.append(cls.CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
||||
return [module for name, module in model.named_modules() if
|
||||
return [(name,module) for name, module in model.named_modules() if
|
||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||
|
||||
@classmethod
|
||||
def clear_requests(cls, model, clear_attn_slice=True):
|
||||
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
|
||||
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
|
||||
for m in self_attention_modules+tokens_attention_modules:
|
||||
m.save_last_attn_slice = False
|
||||
m.use_last_attn_slice = False
|
||||
if clear_attn_slice:
|
||||
m.last_attn_slice = None
|
||||
|
||||
@classmethod
|
||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
# clear out the saved slice in case the outermost dim changes
|
||||
m.last_attn_slice = None
|
||||
m.save_last_attn_slice = True
|
||||
|
||||
@classmethod
|
||||
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
m.use_last_attn_slice = True
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def inject_attention_function(cls, unet):
|
||||
def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
|
||||
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
||||
|
||||
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
|
||||
#memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attn_slice = suggested_attention_slice
|
||||
if dim is not None:
|
||||
start = offset
|
||||
end = start+slice_size
|
||||
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
#else:
|
||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
attention_slice = suggested_attention_slice
|
||||
|
||||
if self.use_last_attn_slice:
|
||||
if dim is None:
|
||||
last_attn_slice = self.last_attn_slice
|
||||
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
#print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
|
||||
context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
#print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + \
|
||||
this_attention_slice * this_mask
|
||||
else:
|
||||
last_attn_slice = self.last_attn_slice[offset]
|
||||
|
||||
if self.last_attn_slice_mask is None:
|
||||
# just use everything
|
||||
attn_slice = last_attn_slice
|
||||
else:
|
||||
last_attn_slice_mask = self.last_attn_slice_mask
|
||||
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
|
||||
attention_slice = saved_attention_slice
|
||||
|
||||
this_attn_slice = attn_slice
|
||||
this_attn_slice_mask = 1 - last_attn_slice_mask
|
||||
attn_slice = this_attn_slice * this_attn_slice_mask + \
|
||||
remapped_last_attn_slice * last_attn_slice_mask
|
||||
|
||||
if self.save_last_attn_slice:
|
||||
if dim is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
else:
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = { offset: attn_slice }
|
||||
else:
|
||||
self.last_attn_slice[offset] = attn_slice
|
||||
|
||||
return attn_slice
|
||||
return attention_slice
|
||||
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.last_attn_slice = None
|
||||
module.last_attn_slice_indices = None
|
||||
module.last_attn_slice_mask = None
|
||||
module.use_last_attn_weights = False
|
||||
module.use_last_attn_slice = False
|
||||
module.save_last_attn_slice = False
|
||||
module.identifier = name
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
||||
context.get_slicing_strategy(module_identifier))
|
||||
|
||||
@classmethod
|
||||
def remove_attention_function(cls, unet):
|
||||
# clear wrangler callback
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
|
||||
|
@ -1,9 +1,11 @@
|
||||
import traceback
|
||||
from math import ceil
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
from ldm.modules.attention import get_mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
@ -34,7 +36,7 @@ class InvokeAIDiffuserComponent:
|
||||
"""
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
self.cross_attention_control_context = None
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||
self.conditioning = conditioning
|
||||
@ -42,11 +44,7 @@ class InvokeAIDiffuserComponent:
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model,
|
||||
cross_attention_control_args=self.conditioning.cross_attention_control_args
|
||||
)
|
||||
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
||||
#todo: apply edit_options using step_count
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
|
||||
|
||||
def remove_cross_attention_control(self):
|
||||
self.conditioning = None
|
||||
@ -54,6 +52,7 @@ class InvokeAIDiffuserComponent:
|
||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||
|
||||
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
conditioning: Union[torch.Tensor,dict],
|
||||
@ -70,12 +69,12 @@ class InvokeAIDiffuserComponent:
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: CrossAttentionControl.Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
|
||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||
@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent:
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
@ -134,32 +133,32 @@ class InvokeAIDiffuserComponent:
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context:CrossAttentionControl.Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
||||
#print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||
#print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
except RuntimeError:
|
||||
# make sure we clean out the attention slices we're storing on the model
|
||||
# TODO don't store things on the model
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module):
|
||||
|
||||
return x+h_
|
||||
|
||||
def get_mem_free_total(device):
|
||||
#only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
@ -173,31 +184,43 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
self.cached_mem_free_total = None
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
self is the current CrossAttention module for which the callback is being invoked.
|
||||
attention_scores are the scores for attention
|
||||
suggested_attention_slice is a softmax(dim=-1) over attention_scores
|
||||
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If dim is >= 0, offset and slice_size specify the slice start and length.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def cache_free_memory_count(self, device):
|
||||
self.cached_mem_free_total = get_mem_free_total(device)
|
||||
print("free cuda memory: ", self.cached_mem_free_total)
|
||||
|
||||
def clear_cached_free_memory_count(self):
|
||||
self.cached_mem_free_total = None
|
||||
|
||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||
# calculate attenion slice by taking the best scores for each latent pixel
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
if self.attention_slice_wrangler is not None:
|
||||
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
@ -240,17 +263,26 @@ class CrossAttention(nn.Module):
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
|
@ -38,4 +38,4 @@ git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||
git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan
|
||||
git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan
|
||||
-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
|
@ -90,7 +90,12 @@ def main():
|
||||
safety_checker=opt.safety_checker,
|
||||
max_loaded_models=opt.max_loaded_models,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
except FileNotFoundError:
|
||||
print('** You appear to be missing configs/models.yaml')
|
||||
print('** You can either exit this script and run scripts/preload_models.py, or fix the problem now.')
|
||||
emergency_model_create(opt)
|
||||
sys.exit(-1)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
@ -485,6 +490,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
command = '-h'
|
||||
return command, operation
|
||||
|
||||
|
||||
def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
print(f'>> Model import in process. Please enter the values needed to configure this model:')
|
||||
print()
|
||||
@ -581,7 +587,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
|
||||
|
||||
try:
|
||||
print('>> Verifying that new model loads...')
|
||||
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
gen.model_cache.add_model(model_name, new_config, clobber)
|
||||
assert gen.set_model(model_name) is not None, 'model failed to load'
|
||||
except AssertionError as e:
|
||||
print(f'** aborting **')
|
||||
@ -894,6 +900,36 @@ def write_commands(opt, file_path:str, outfilepath:str):
|
||||
f.write('\n'.join(commands))
|
||||
print(f'>> File {outfilepath} with commands created')
|
||||
|
||||
def emergency_model_create(opt:Args):
|
||||
completer = get_completer(opt)
|
||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
||||
completer.set_default_dir('.')
|
||||
valid_path = False
|
||||
while not valid_path:
|
||||
weights_file = input('Enter the path to a downloaded models file, or ^C to exit: ')
|
||||
valid_path = os.path.exists(weights_file)
|
||||
dir,basename = os.path.split(weights_file)
|
||||
|
||||
valid_name = False
|
||||
while not valid_name:
|
||||
name = input('Enter a short name for this model (no spaces): ')
|
||||
name = 'unnamed model' if len(name)==0 else name
|
||||
valid_name = ' ' not in name
|
||||
|
||||
description = input('Enter a description for this model: ')
|
||||
description = 'no description' if len(description)==0 else description
|
||||
|
||||
with open(opt.conf, 'w', encoding='utf-8') as f:
|
||||
f.write(f'{name}:\n')
|
||||
f.write(f' description: {description}\n')
|
||||
f.write(f' weights: {weights_file}\n')
|
||||
f.write(f' config: ./configs/stable-diffusion/v1-inference.yaml\n')
|
||||
f.write(f' width: 512\n')
|
||||
f.write(f' height: 512\n')
|
||||
f.write(f' default: true\n')
|
||||
print(f'Config file {opt.conf} is created. This script will now exit.')
|
||||
print(f'After restarting you may examine the entry with !models and edit it with !edit.')
|
||||
|
||||
######################################
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user