mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

🚧 post-rebase repair

commit 494936a8d2
parent adaa1c7c3e
Changed file: .github/workflows/test-invoke-conda.yml (vendored), 4 lines changed
@@ -86,14 +86,14 @@ jobs:
         if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
         run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV

-      - name: run preload_models.py
+      - name: run configure_invokeai.py
         id: run-preload-models
         run: |
           if [ "${HAVE_SECRETS}" == true ] ; then
             mkdir -p ~/.huggingface
             echo -n '${{ secrets.HUGGINGFACE_TOKEN }}' > ~/.huggingface/token
           fi
-          python scripts/preload_models.py \
+          python scripts/configure_invokeai.py \
           --no-interactive --yes \
           --full-precision # can't use fp16 weights without a GPU

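This step primes the Hugging Face token cache before invoking the renamed configure_invokeai.py script. For local debugging, a rough Python equivalent of the shell above might look like this (a minimal sketch; the HUGGINGFACE_TOKEN environment variable name mirrors the secret and is an assumption, not part of the commit):

    import os
    from pathlib import Path

    token = os.environ.get("HUGGINGFACE_TOKEN", "")
    if token:
        hf_dir = Path.home() / ".huggingface"
        hf_dir.mkdir(parents=True, exist_ok=True)
        # echo -n equivalent: write the token without a trailing newline
        (hf_dir / "token").write_text(token)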
@@ -1,33 +1,31 @@
-import eventlet
+import base64
 import glob
+import io
+import json
+import math
+import mimetypes
 import os
 import shutil
-import mimetypes
 import traceback
-import math
-import io
-import base64
-import os
-import json
+from threading import Event
+from uuid import uuid4

-from werkzeug.utils import secure_filename
+import eventlet
+from PIL import Image
+from PIL.Image import Image as ImageType
 from flask import Flask, redirect, send_from_directory, request, make_response
 from flask_socketio import SocketIO
-from PIL import Image, ImageOps
-from PIL.Image import Image as ImageType
-from uuid import uuid4
-from threading import Event

-from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
-from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.prompt_parser import split_weighted_subprompts
-from ldm.invoke.generator.inpaint import infill_methods
-
-from backend.modules.parameters import parameters_to_command
 from backend.modules.get_canvas_generation_mode import (
     get_canvas_generation_mode,
 )
+from backend.modules.parameters import parameters_to_command
+from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
+from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
+from ldm.invoke.generator.inpaint import infill_methods
+from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
+from ldm.invoke.prompt_parser import split_weighted_subprompts

 # Loading Arguments
 opt = Args()
@@ -251,7 +249,7 @@ class InvokeAIWebServer:
                 return candidate
         assert "Frontend files cannot be found. Cannot continue"


     def setup_app(self):
         self.result_url = "outputs/"
         self.init_image_url = "outputs/init-images/"
@@ -776,10 +774,10 @@ class InvokeAIWebServer:
                 ).convert("RGBA")

                 """
                 The outpaint image and mask are pre-cropped by the UI, so the bounding box we pass
                 to the generator should be:
                 {
                     "x": 0,
                     "y": 0,
                     "width": original_bounding_box["width"],
                     "height": original_bounding_box["height"]
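The docstring above describes a coordinate reset: because the UI pre-crops the outpaint image and mask, only the selection's size survives into the generator call. A minimal sketch of that rule (the helper name is hypothetical, not from the commit):

    def outpaint_bounding_box(original_bounding_box: dict) -> dict:
        # The pre-crop puts the selection at the origin; keep only its size.
        return {
            "x": 0,
            "y": 0,
            "width": original_bounding_box["width"],
            "height": original_bounding_box["height"],
        }

    # A 512x512 selection anywhere on the canvas maps to the origin:
    assert outpaint_bounding_box(
        {"x": 128, "y": 64, "width": 512, "height": 512}
    ) == {"x": 0, "y": 0, "width": 512, "height": 512}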
@@ -799,7 +797,7 @@
                 )

                 """
                 Apply the mask to the init image, creating a "mask" image with
                 transparency where inpainting should occur. This is the kind of
                 mask that prompt2image() needs.
                 """
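A minimal PIL sketch of what that docstring describes, assuming the mask is a single-channel image whose values become the alpha of the init image (the helper name and the mask polarity are assumptions; whether the mask needs inverting depends on the UI's convention):

    from PIL import Image

    def make_prompt2image_mask(init_image: Image.Image, mask: Image.Image) -> Image.Image:
        # Use the mask as the alpha channel of the init image, so the
        # region to inpaint becomes transparent.
        masked = init_image.convert("RGBA")
        masked.putalpha(mask.convert("L"))
        return masked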
@@ -40,15 +40,6 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.7.0
     - transformers==4.21.3
-    - diffusers~=0.7
-    - torchmetrics==0.7.0
-    - flask==2.1.3
-    - flask_socketio==5.3.0
-    - flask_cors==3.0.10
-    - dependency_injector==4.40.0
-    - eventlet
-    - getpass_asterisk
-    - kornia==0.6.0
     - git+https://github.com/openai/CLIP.git@main#egg=clip
     - git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
     - git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
@@ -236,7 +236,7 @@ class Generate:
         except Exception:
             print('** An error was encountered while installing the safety checker:')
             print(traceback.format_exc())

     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -330,7 +330,7 @@ class Generate:
             infill_method = infill_methods[0], # The infill method to use
             force_outpaint: bool = False,
             enable_image_debugging = False,

             **args,
     ):  # eat up additional cruft
         """
@@ -373,7 +373,7 @@ class Generate:
                def process_image(image,seed):
                    image.save(f{'images/seed.png'})

        The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
        It contains code to create the requested output directory, select a unique informative
        name for each image, and write the prompt into the PNG metadata.
        """
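Note that the docstring's callback example, image.save(f{'images/seed.png'}), is not valid Python syntax. Presumably the intent was an f-string that interpolates the seed into the filename:

    def process_image(image, seed):
        # save each result under a seed-derived name
        image.save(f'images/{seed}.png')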
@@ -593,7 +593,7 @@ class Generate:
             seed = opt.seed or args.seed
             if seed is None or seed < 0:
                 seed = random.randrange(0, np.iinfo(np.uint32).max)

             prompt = opt.prompt or args.prompt or ''
             print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')

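The seed handling above amounts to one small rule; restated as a self-contained sketch (the function name is hypothetical):

    import random

    import numpy as np

    def normalize_seed(seed):
        # Missing or negative seeds are replaced with a random value that
        # fits in an unsigned 32-bit integer, matching the hunk above.
        if seed is None or seed < 0:
            seed = random.randrange(0, np.iinfo(np.uint32).max)
        return seed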
@@ -645,7 +645,7 @@ class Generate:

             opt.seed = seed
             opt.prompt = prompt

             if len(extend_instructions) > 0:
                 restorer = Outcrop(image,self,)
                 return restorer.process (
@@ -687,7 +687,7 @@ class Generate:
                 image_callback = callback,
                 prefix = prefix
             )

         elif tool is None:
             print(f'* please provide at least one postprocessing option, such as -G or -U')
             return None
@@ -710,13 +710,13 @@ class Generate:

         if embiggen is not None:
             return self._make_embiggen()

         if inpainting_model_in_use:
             return self._make_omnibus()

         if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
             return self._make_inpaint()

         if init_image is not None:
             return self._make_img2img()

@@ -747,7 +747,7 @@ class Generate:
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask, force_outpaint)
             init_mask = self._create_init_mask(image, width, height, fit=fit)

         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False
@@ -763,7 +763,7 @@ class Generate:

         if invert_mask:
             init_mask = ImageOps.invert(init_mask)

         return init_image,init_mask

     # lots o' repeated code here! Turn into a make_func()
@@ -822,7 +822,7 @@ class Generate:
             self.set_model(self.model_name)

     def set_model(self,model_name):
         """
         Given the name of a model defined in models.yaml, will load and initialize it
         and return the model object. Previously-used models will be cached.
         """
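A hypothetical usage sketch of the caching behavior that docstring promises (the constructor arguments and model names are placeholders, not from the commit):

    gen = Generate()
    model = gen.set_model('stable-diffusion-1.5')   # loads and initializes
    other = gen.set_model('waifu-diffusion')        # loads a second model
    model = gen.set_model('stable-diffusion-1.5')   # served from the cache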
@@ -834,7 +834,7 @@ class Generate:
         if not cache.valid_model(model_name):
             print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
             return self.model

         cache.print_vram_usage()

         # have to get rid of all references to model in order
@@ -843,7 +843,7 @@ class Generate:
         self.sampler = None
         self.generators = {}
         gc.collect()

         model_data = cache.get_model(model_name)
         if model_data is None: # restore previous
             model_data = cache.get_model(self.model_name)
@@ -856,7 +856,7 @@ class Generate:

         # uncache generators so they pick up new models
         self.generators = {}

         seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path is not None:
             self.model.embedding_manager.load(
@@ -905,7 +905,7 @@ class Generate:
             image_callback = None,
             prefix = None,
     ):

         for r in image_list:
             image, seed = r
             try:
@@ -915,7 +915,7 @@ class Generate:
                     if self.gfpgan is None:
                         print('>> GFPGAN not found. Face restoration is disabled.')
                     else:
                         image = self.gfpgan.process(image, strength, seed)
                 if facetool == 'codeformer':
                     if self.codeformer is None:
                         print('>> CodeFormer not found. Face restoration is disabled.')
@@ -9,9 +9,10 @@ import os.path as osp
 import random
 import traceback

+import cv2
 import numpy as np
 import torch
-from PIL import Image, ImageFilter
+from PIL import Image, ImageFilter, ImageChops
 from diffusers import DiffusionPipeline
 from einops import rearrange
 from pytorch_lightning import seed_everything
@@ -169,7 +170,7 @@ class Generator:
         # Blur the mask out (into init image) by specified amount
         if mask_blur_radius > 0:
             nm = np.asarray(pil_init_mask, dtype=np.uint8)
-            nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
             pmd = Image.fromarray(nmd, mode='L')
             blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
         else:
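The fix above corrects a NameError: the module is imported as cv2 (added in the previous hunk), not cv. The erode-then-blur step restated as a self-contained sketch (the function name is hypothetical):

    import cv2
    import numpy as np
    from PIL import Image, ImageFilter

    def soften_mask(pil_init_mask: Image.Image, mask_blur_radius: float) -> Image.Image:
        # Erode first so the blur bleeds into the masked region rather than
        # out of it, then soften the edge with a box blur.
        nm = np.asarray(pil_init_mask, dtype=np.uint8)
        nmd = cv2.erode(nm, kernel=np.ones((3, 3), dtype=np.uint8),
                        iterations=int(mask_blur_radius / 2))
        pmd = Image.fromarray(nmd, mode='L')
        return pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))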
@@ -181,8 +182,6 @@ class Generator:
         matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
         return matched_result

-
-
     def sample_to_lowres_estimated_image(self,samples):
         # origingally adapted from code by @erucipe and @keturn here:
         # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
@@ -21,9 +21,6 @@ from typing import Union

 import torch
 import transformers
-import textwrap
-import contextlib
-from typing import Union
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from picklescan.scanner import scan_file_path
@@ -99,7 +96,7 @@ class ModelCache(object):
         assert self.current_model,'** FATAL: no current model to restore to'
         print(f'** restoring {self.current_model}')
         self.get_model(self.current_model)
-        return None
+        return

         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -219,7 +216,7 @@ class ModelCache(object):
         if model_format == 'ckpt':
             weights = mconfig.weights
             print(f'>> Loading {model_name} from {weights}')
-            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+            model, width, height, model_hash = self._load_ckpt_model(model_name, mconfig)
         elif model_format == 'diffusers':
             model, width, height, model_hash = self._load_diffusers_model(mconfig)
         else:
@@ -237,10 +234,10 @@ class ModelCache(object):
         )
         return model, width, height, model_hash

-    def _load_ckpt_model(self, mconfig):
+    def _load_ckpt_model(self, model_name, mconfig):
         config = mconfig.config
         weights = mconfig.weights
-        vae = mconfig.get('vae', None)
+        vae = mconfig.get('vae')
         width = mconfig.width
         height = mconfig.height

@@ -249,10 +246,22 @@ class ModelCache(object):
         if not os.path.isabs(weights):
             weights = os.path.normpath(os.path.join(Globals.root,weights))
         # scan model
-        self._scan_model(model_name, weights)
+        self.scan_model(model_name, weights)

-        c = OmegaConf.load(config)
-        with open(weights, 'rb') as f:
+        print(f'>> Loading {model_name} from {weights}')
+
+        # for usage statistics
+        if self._has_cuda():
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.empty_cache()
+
+        tic = time.time()
+
+        # this does the work
+        if not os.path.isabs(config):
+            config = os.path.join(Globals.root,config)
+        omega_config = OmegaConf.load(config)
+        with open(weights,'rb') as f:
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
         sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
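The load path reads the checkpoint into memory once and reuses the same bytes for both the cache hash and torch.load. A self-contained sketch of that pattern, substituting a plain sha256 for the _cached_sha256 helper:

    import hashlib
    import io

    import torch

    def load_checkpoint_with_hash(weights_path: str):
        # One read serves two consumers, hashing and deserialization,
        # avoiding a second pass over a multi-gigabyte file.
        with open(weights_path, 'rb') as f:
            weight_bytes = f.read()
        model_hash = hashlib.sha256(weight_bytes).hexdigest()
        sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
        return sd, model_hash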
@@ -289,6 +298,18 @@ class ModelCache(object):
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode

+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+
         return model, width, height, model_hash

     def _load_diffusers_model(self, mconfig):
@@ -308,6 +329,8 @@ class ModelCache(object):

         print(f'>> Loading diffusers model from {name_or_path}')

+        # TODO: scan weights maybe?
+
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
             pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
@@ -342,7 +365,7 @@ class ModelCache(object):
         else:
             raise ValueError("Model config must specify either repo_name or path.")

-    def offload_model(self, model_name:str):
+    def offload_model(self, model_name:str) -> None:
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
@@ -248,33 +248,33 @@ def inject_attention_function(unet, context: Context):

     cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
                                if type(module).__name__ == "CrossAttention"]
     for identifier, module in cross_attention_modules:
         module.identifier = identifier
         try:
             module.set_attention_slice_wrangler(attention_slice_wrangler)
             module.set_slicing_strategy_getter(
                 lambda module: context.get_slicing_strategy(identifier)
             )
         except AttributeError as e:
             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                 warnings.warn(f"TODO: implement for {type(module)}")  # TODO
             else:
                 raise


 def remove_attention_function(unet):
     cross_attention_modules = [module for (_, module) in unet.named_modules()
                                if type(module).__name__ == "CrossAttention"]
     for module in cross_attention_modules:
         try:
             # clear wrangler callback
             module.set_attention_slice_wrangler(None)
             module.set_slicing_strategy_getter(None)
         except AttributeError as e:
             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                 warnings.warn(f"TODO: implement for {type(module)}")  # TODO
             else:
                 raise


 def is_attribute_error_about(error: AttributeError, attribute: str):
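The hunk ends at the signature of is_attribute_error_about; its body is not shown in this diff. One plausible body, offered purely as a sketch and not as the commit's actual implementation:

    def is_attribute_error_about(error: AttributeError, attribute: str):
        # Heuristic: AttributeError messages name the missing attribute,
        # e.g. "'Foo' object has no attribute 'set_attention_slice_wrangler'".
        return attribute in str(error)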
@@ -34,6 +34,12 @@ warnings.filterwarnings('ignore')
 import torch
 transformers.logging.set_verbosity_error()

+try:
+    from ldm.invoke.model_cache import ModelCache
+except ImportError:
+    sys.path.append('.')
+    from ldm.invoke.model_cache import ModelCache
+
 #--------------------------globals-----------------------
 Model_dir = 'models'
 Weights_dir = 'ldm/stable-diffusion-v1/'
@@ -267,6 +273,19 @@ def download_weight_datasets(models:dict, access_token:str):
     print(f'Successfully installed {keys}')
     return successful

+#---------------------------------------------
+def is_huggingface_authenticated():
+    # huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
+    # maybe other things, not all end-user-friendly.
+    # noinspection PyBroadException
+    try:
+        response = hf_whoami()
+        if response.get('id') is not None:
+            return True
+    except Exception:
+        pass
+    return False
+
 #---------------------------------------------
 def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
     model_dest = os.path.join(model_dir, model_name)
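A hypothetical caller for the new helper, gating the token prompt on an existing login (hf_whoami is presumably huggingface_hub's whoami imported under an alias elsewhere in the script; the prompt text below is illustrative, not from the commit):

    if not is_huggingface_authenticated():
        print('** You are not currently authenticated with the Hugging Face Hub.')
        access_token = input('Paste your Hugging Face access token: ').strip()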
@@ -749,6 +768,12 @@ def main():
                         action=argparse.BooleanOptionalAction,
                         default=True,
                         help='run in interactive mode (default)')
+    parser.add_argument('--full-precision',
+                        dest='full_precision',
+                        action=argparse.BooleanOptionalAction,
+                        type=bool,
+                        default=False,
+                        help='use 32-bit weights instead of faster 16-bit weights')
     parser.add_argument('--yes','-y',
                         dest='yes_to_all',
                         action='store_true',
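argparse.BooleanOptionalAction (Python 3.9+) generates a paired negative flag automatically, so the new option accepts both --full-precision and --no-full-precision. A minimal sketch of the resulting behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--full-precision',
                        dest='full_precision',
                        action=argparse.BooleanOptionalAction,
                        type=bool,   # accepted but unused by this action
                        default=False,
                        help='use 32-bit weights instead of faster 16-bit weights')

    print(parser.parse_args(['--full-precision']).full_precision)     # True
    print(parser.parse_args(['--no-full-precision']).full_precision)  # False
    print(parser.parse_args([]).full_precision)                       # False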