mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/controlnet_cfg_inj_cond
This commit is contained in:
@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import Callable, List, Iterator, Optional, Type
|
||||
from typing import Callable, List, Iterator, Optional, Type, Union
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float=0.75,
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
@ -570,28 +570,16 @@ class Generator:
|
||||
device = self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
if self.use_mps_noise or device.type == "mps":
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device="cpu",
|
||||
).to(device)
|
||||
else:
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device,
|
||||
)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device,
|
||||
)
|
||||
if self.perlin > 0.0:
|
||||
perlin_noise = self.get_perlin_noise(
|
||||
width // self.downsampling_factor, height // self.downsampling_factor
|
||||
|
@ -88,10 +88,7 @@ class Img2Img(Generator):
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu").to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device)
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
|
@ -4,11 +4,10 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
@ -76,7 +75,7 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
@ -203,8 +202,8 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
|
@ -45,6 +45,7 @@ from invokeai.app.services.config import (
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
@ -76,7 +77,7 @@ Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
PRECISION_CHOICES = ['auto','float16','float32']
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
@ -359,9 +360,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """If you have an account at HuggingFace you may optionally paste your access token here
|
||||
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
|
||||
"""
|
||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
@ -423,6 +422,7 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
)
|
||||
self.precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
columns = 2,
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
@ -430,13 +430,13 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
max_height=len(PRECISION_CHOICES) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_loaded_models = self.add_widget_intelligent(
|
||||
self.max_cache_size = self.add_widget_intelligent(
|
||||
IntTitleSlider,
|
||||
name="Number of models to cache in CPU memory (each will use 2-4 GB!)",
|
||||
value=old_opts.max_loaded_models,
|
||||
out_of=10,
|
||||
lowest=1,
|
||||
begin_entry_at=4,
|
||||
name="Size of the RAM cache used for fast model switching (GB)",
|
||||
value=old_opts.max_cache_size,
|
||||
out_of=20,
|
||||
lowest=3,
|
||||
begin_entry_at=6,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@ -539,7 +539,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
@ -555,9 +555,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
|
||||
|
@ -4,6 +4,8 @@ import argparse
|
||||
import shlex
|
||||
from argparse import ArgumentParser
|
||||
|
||||
# note that this includes both old sampler names and new scheduler names
|
||||
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"ddpm",
|
||||
@ -27,6 +29,15 @@ SAMPLER_CHOICES = [
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
|
@ -3,7 +3,6 @@ Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
'''
|
||||
|
||||
import io
|
||||
import os
|
||||
import argparse
|
||||
import shutil
|
||||
@ -28,9 +27,10 @@ from transformers import (
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
|
||||
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -47,48 +47,27 @@ class ModelPaths:
|
||||
|
||||
class MigrateTo3(object):
|
||||
def __init__(self,
|
||||
root_directory: Path,
|
||||
dest_models: Path,
|
||||
yaml_file: io.TextIOBase,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = root_directory
|
||||
self.dest_models = dest_models
|
||||
self.dest_yaml = yaml_file
|
||||
self.model_names = set()
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.src_paths = src_paths
|
||||
|
||||
self._initialize_yaml()
|
||||
|
||||
def _initialize_yaml(self):
|
||||
self.dest_yaml.write(
|
||||
yaml.dump(
|
||||
{
|
||||
'__metadata__':
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, 'w') as file:
|
||||
file.write(
|
||||
yaml.dump(
|
||||
{
|
||||
'version':'3.0.0'}
|
||||
}
|
||||
'__metadata__': {'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def unique_name(self,name,info)->str:
|
||||
'''
|
||||
Create a unique name for a model for use within models.yaml.
|
||||
'''
|
||||
done = False
|
||||
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
||||
unique_name = key
|
||||
counter = 1
|
||||
while not done:
|
||||
if unique_name in self.model_names:
|
||||
unique_name = f'{key}-{counter:0>2d}'
|
||||
counter += 1
|
||||
else:
|
||||
done = True
|
||||
self.model_names.add(unique_name)
|
||||
name,_,_ = ModelManager.parse_key(unique_name)
|
||||
return name
|
||||
|
||||
def create_directory_structure(self):
|
||||
'''
|
||||
Create the basic directory structure for the models folder.
|
||||
@ -136,23 +115,8 @@ class MigrateTo3(object):
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
'''
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
for f in files:
|
||||
# hack - don't copy raw learned_embeds.bin, let them
|
||||
# be copied as part of a tree copy operation
|
||||
if f == 'learned_embeds.bin':
|
||||
continue
|
||||
try:
|
||||
model = Path(root,f)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root,d)
|
||||
@ -161,6 +125,29 @@ class MigrateTo3(object):
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
|
||||
continue
|
||||
model = Path(root,f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -219,11 +206,12 @@ class MigrateTo3(object):
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id= repo_id,
|
||||
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
|
||||
dest= target_dir / 'clip-vit-large-patch14',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
|
||||
dest = target_dir / 'clip-vit-large-patch14',
|
||||
force = True,
|
||||
**kwargs)
|
||||
|
||||
# sd-2
|
||||
@ -262,46 +250,24 @@ class MigrateTo3(object):
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def write_yaml(self, model_name: str, path:Path, info:ModelProbeInfo, **kwargs):
|
||||
'''
|
||||
Write a stanza for a moved model into the new models.yaml file.
|
||||
'''
|
||||
name = self.unique_name(model_name, info)
|
||||
stanza = {
|
||||
f'{info.base_type.value}/{info.model_type.value}/{name}': {
|
||||
'name': model_name,
|
||||
'path': str(path),
|
||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||
'format': info.format,
|
||||
'image_size': info.image_size,
|
||||
'base': info.base_type.value,
|
||||
'variant': info.variant_type.value,
|
||||
'prediction_type': info.prediction_type.value,
|
||||
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction,
|
||||
**kwargs,
|
||||
}
|
||||
}
|
||||
self.dest_yaml.write(yaml.dump(stanza))
|
||||
self.dest_yaml.flush()
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
|
||||
if dest.exists():
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
|
||||
if dest.exists() and not force:
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest)
|
||||
self._save_pretrained(model, dest, overwrite=force)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path):
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
|
||||
model_name = dest.name
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
if overwrite:
|
||||
model.save_pretrained(dest, safe_serialization=True)
|
||||
else:
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||
@ -327,6 +293,7 @@ class MigrateTo3(object):
|
||||
elif repo_id := vae.get('repo_id'):
|
||||
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||
vae_path = 'models/core/convert/sd-vae-ft-mse'
|
||||
return vae_path
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||
|
||||
@ -339,7 +306,10 @@ class MigrateTo3(object):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
self.copy_dir(vae_path,dest)
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path,dest)
|
||||
else:
|
||||
self.copy_file(vae_path,dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
@ -348,7 +318,7 @@ class MigrateTo3(object):
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
|
||||
'''
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
'''
|
||||
@ -380,11 +350,15 @@ class MigrateTo3(object):
|
||||
if not info:
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / repo_name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / model_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
|
||||
self._add_model(model_name, info, rel_path, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||
'''
|
||||
@ -394,20 +368,49 @@ class MigrateTo3(object):
|
||||
# handle relative paths
|
||||
dest_dir = self.dest_models
|
||||
location = self.root_directory / location
|
||||
model_name = model_name or location.stem
|
||||
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
return
|
||||
|
||||
# uh oh, weights is in the old models directory - move it into the new one
|
||||
if Path(location).is_relative_to(self.src_paths.models):
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||
self.copy_dir(location,dest)
|
||||
if location.is_dir():
|
||||
self.copy_dir(location,dest)
|
||||
else:
|
||||
self.copy_file(location,dest)
|
||||
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
||||
model_name = model_name or location.stem
|
||||
model_name = self.unique_name(model_name, info)
|
||||
self.write_yaml(model_name, path=location, info=info, **extra_config)
|
||||
|
||||
self._add_model(model_name, info, location, **extra_config)
|
||||
|
||||
def _add_model(self,
|
||||
model_name: str,
|
||||
info: ModelProbeInfo,
|
||||
location: Path,
|
||||
**extra_config):
|
||||
if info.model_type != ModelType.Main:
|
||||
return
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
clobber = True,
|
||||
model_attributes = {
|
||||
'path': str(location),
|
||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||
'model_format': info.format,
|
||||
'variant': info.variant_type.value,
|
||||
**extra_config,
|
||||
}
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
'''
|
||||
Migrate models defined in models.yaml
|
||||
@ -429,6 +432,9 @@ class MigrateTo3(object):
|
||||
|
||||
if config := stanza.get('config'):
|
||||
passthru_args['config'] = config
|
||||
|
||||
if description:= stanza.get('description'):
|
||||
passthru_args['description'] = description
|
||||
|
||||
if repo_id := stanza.get('repo_id'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
@ -509,31 +515,50 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / 'configs' / 'models.yaml.3'
|
||||
dest_models = dest_directory / 'models.3'
|
||||
|
||||
dest_models = dest_directory / 'models-3.0'
|
||||
dest_yaml = dest_directory / 'configs/models.yaml-3.0'
|
||||
version_3 = (dest_directory / 'models' / 'core').exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
|
||||
except:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / 'models').replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(
|
||||
from_root = src_directory,
|
||||
to_models = dest_models,
|
||||
model_manager = mgr,
|
||||
src_paths = paths
|
||||
)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
with open(dest_yaml,'w') as yaml_file:
|
||||
migrator = MigrateTo3(src_directory,
|
||||
dest_models,
|
||||
yaml_file,
|
||||
src_paths = paths,
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
|
||||
(dest_directory / 'models').replace(dest_directory / 'models.orig')
|
||||
dest_models.replace(dest_directory / 'models')
|
||||
|
||||
(dest_directory /'configs/models.yaml').replace(dest_directory / 'configs/models.yaml.orig')
|
||||
dest_yaml.replace(dest_directory / 'configs/models.yaml')
|
||||
print(f"""Migration successful.
|
||||
Original models directory moved to {dest_directory}/models.orig
|
||||
Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig
|
||||
""")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / 'models').replace(src_directory / 'models.orig')
|
||||
print(f'Original models directory moved to {dest_directory}/models.orig')
|
||||
|
||||
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
|
||||
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
|
||||
|
||||
config_file.replace(config_file.with_suffix(''))
|
||||
dest_models.replace(dest_models.with_suffix(''))
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
||||
description="""
|
||||
@ -545,34 +570,34 @@ It is safe to provide the same directory for both arguments, but it is better to
|
||||
script, which will perform a full upgrade in place."""
|
||||
)
|
||||
parser.add_argument('--from-directory',
|
||||
dest='root_directory',
|
||||
dest='src_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
||||
)
|
||||
parser.add_argument('--to-directory',
|
||||
dest='dest_directory',
|
||||
dest='dest_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||
)
|
||||
# TO DO: Implement full directory scanning
|
||||
# parser.add_argument('--all-models',
|
||||
# action="store_true",
|
||||
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
|
||||
# )
|
||||
args = parser.parse_args()
|
||||
root_directory = args.root_directory
|
||||
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
|
||||
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
|
||||
assert (root_directory / 'invokeai.init').exists() or (root_directory / 'invokeai.yaml').exists(), f"{root_directory} does not contain an InvokeAI init file."
|
||||
src_root = args.src_root
|
||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
|
||||
dest_directory = args.dest_directory
|
||||
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
|
||||
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
|
||||
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
|
||||
dest_root = args.dest_root
|
||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(['--root',str(dest_root)])
|
||||
|
||||
do_migrate(root_directory,dest_directory)
|
||||
# TODO: revisit
|
||||
# assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
|
||||
# assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
|
||||
|
||||
do_migrate(src_root,dest_root)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -11,6 +11,7 @@ from typing import List, Dict, Callable, Union, Set
|
||||
|
||||
import requests
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
@ -18,7 +19,7 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from ..util.logging import InvokeAILogger
|
||||
@ -153,6 +154,9 @@ class ModelInstall(object):
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
@ -160,79 +164,87 @@ class ModelInstall(object):
|
||||
for key in selections.remove_models:
|
||||
name,base,mtype = self.mgr.parse_key(key)
|
||||
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
||||
self.mgr.del_model(name,base,mtype)
|
||||
try:
|
||||
self.mgr.del_model(name,base,mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_install(path)
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_install(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Set[Path]:
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
|
||||
if not models_installed:
|
||||
models_installed = set()
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.add(self._install_path(path))
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
models_installed.add(self._install_path(path))
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_install(child, models_installed=models_installed)
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.add(self._install_repo(str(path)))
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.add(self._install_url(model_path_id_or_url))
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
else:
|
||||
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
return models_installed
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
||||
try:
|
||||
# logger.debug(f'Probing {path}')
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return path
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
return self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str)->Path:
|
||||
# copy to a staging area, probe, import and delete
|
||||
def _install_url(self, url: str)->AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
if not location:
|
||||
@ -244,7 +256,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->Path:
|
||||
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
@ -270,16 +282,16 @@ class ModelInstall(object):
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
elif f'learned_embeds.{suffix}' in files:
|
||||
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
|
||||
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
||||
return
|
||||
|
||||
return {}
|
||||
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Could not probe {location}. Skipping install.')
|
||||
return
|
||||
return {}
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||
|
||||
|
@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .model_manager import ModelManager
|
||||
from .model_cache import ModelCache
|
||||
from picklescan.scanner import scan_file_path
|
||||
from .models import BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
else:
|
||||
if scan_needed:
|
||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
|
@ -1,18 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from typing import Optional, Dict, Tuple, Any, Union, List
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
@ -124,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
def get_weight(self):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
@ -411,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(
|
||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
||||
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||
)
|
||||
return
|
||||
|
||||
@ -539,9 +536,10 @@ class ModelPatcher:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
# enable autocast to calc fp16 loras on cpu
|
||||
with torch.autocast(device_type="cpu"):
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||
#with torch.autocast(device_type="cpu"):
|
||||
layer.to(dtype=torch.float32)
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
@ -617,6 +615,24 @@ class ModelPatcher:
|
||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_clip_skip(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
clip_skip: int,
|
||||
):
|
||||
skipped_layers = []
|
||||
try:
|
||||
for i in range(clip_skip):
|
||||
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
while len(skipped_layers) > 0:
|
||||
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
||||
|
||||
class TextualInversionModel:
|
||||
name: str
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
@ -655,6 +671,9 @@ class TextualInversionModel:
|
||||
else:
|
||||
result.embedding = next(iter(state_dict.values()))
|
||||
|
||||
if len(result.embedding.shape) == 1:
|
||||
result.embedding = result.embedding.unsqueeze(0)
|
||||
|
||||
if not isinstance(result.embedding, torch.Tensor):
|
||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
|
@ -8,7 +8,7 @@ The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_models_cached=6)
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
@ -36,6 +36,9 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE= 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
@ -82,6 +85,7 @@ class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device=torch.device('cuda'),
|
||||
storage_device: torch.device=torch.device('cpu'),
|
||||
precision: torch.dtype=torch.float16,
|
||||
@ -91,7 +95,7 @@ class ModelCache(object):
|
||||
logger: types.ModuleType = logger
|
||||
):
|
||||
'''
|
||||
:param max_models: Maximum number of models to cache in CPU RAM [4]
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
@ -99,14 +103,11 @@ class ModelCache(object):
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
'''
|
||||
#max_cache_size = 9999
|
||||
execution_device = torch.device('cuda')
|
||||
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.lazy_offloading = lazy_offloading
|
||||
#self.sequential_offload: bool=sequential_offload
|
||||
self.precision: torch.dtype=precision
|
||||
self.max_cache_size: int=max_cache_size
|
||||
self.max_cache_size: float=max_cache_size
|
||||
self.max_vram_cache_size: float=max_vram_cache_size
|
||||
self.execution_device: torch.device=execution_device
|
||||
self.storage_device: torch.device=storage_device
|
||||
self.sha_chunksize=sha_chunksize
|
||||
@ -128,16 +129,6 @@ class ModelCache(object):
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
#def get_model(
|
||||
# self,
|
||||
# repo_id_or_path: Union[str, Path],
|
||||
# model_type: ModelType = ModelType.Diffusers,
|
||||
# subfolder: Path = None,
|
||||
# submodel: ModelType = None,
|
||||
# revision: str = None,
|
||||
# attach_model_part: Tuple[ModelType, str] = (None, None),
|
||||
# gpu_load: bool = True,
|
||||
#) -> ModelLocker: # ?? what does it return
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
@ -213,14 +204,22 @@ class ModelCache(object):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load)
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load):
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
'''
|
||||
:param cache: The model_cache object
|
||||
:param key: The key of the model to lock in GPU
|
||||
:param model: The model to lock
|
||||
:param gpu_load: True if load into gpu
|
||||
:param size_needed: Size of the model to load
|
||||
'''
|
||||
self.gpu_load = gpu_load
|
||||
self.cache = cache
|
||||
self.key = key
|
||||
self.model = model
|
||||
self.size_needed = size_needed
|
||||
self.cache_entry = self.cache._cached_models[self.key]
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
@ -234,7 +233,7 @@ class ModelCache(object):
|
||||
|
||||
try:
|
||||
if self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
self.cache._offload_unlocked_models(self.size_needed)
|
||||
|
||||
if self.model.device != self.cache.execution_device:
|
||||
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
|
||||
@ -349,12 +348,20 @@ class ModelCache(object):
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
|
||||
def _offload_unlocked_models(self):
|
||||
for model_key, cache_entry in self._cached_models.items():
|
||||
def _offload_unlocked_models(self, size_needed: int=0):
|
||||
reserved = self.max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
||||
cache_entry.model.to(self.storage_device)
|
||||
with VRAMUsage() as mem:
|
||||
cache_entry.model.to(self.storage_device)
|
||||
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
|
||||
vram_in_use += mem.vram_used # note vram_used is negative
|
||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
|
@ -52,7 +52,7 @@ A typical example is:
|
||||
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
||||
model_type=ModelType.Main,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
submodel_type=SubModelType.Unet)
|
||||
submodel_type=SubModelType.UNet)
|
||||
with sd1_5 as unet:
|
||||
run_some_inference(unet)
|
||||
|
||||
@ -231,16 +231,17 @@ from __future__ import annotations
|
||||
import os
|
||||
import hashlib
|
||||
import textwrap
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
||||
from shutil import rmtree
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree, move
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -249,7 +250,7 @@ from .model_cache import ModelCache, ModelLocker
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
ModelConfigBase,
|
||||
ModelConfigBase, ModelNotFoundException,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -278,8 +279,13 @@ class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
base_model: BaseModelType = Field(description="The base model")
|
||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
|
||||
class ConfigMeta(BaseModel):
|
||||
version: str
|
||||
@ -306,10 +312,12 @@ class ModelManager(object):
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
|
||||
self.config_path = None
|
||||
if isinstance(config, (str, Path)):
|
||||
self.config_path = Path(config)
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
|
||||
self.initialize_model_config(self.config_path)
|
||||
config = OmegaConf.load(self.config_path)
|
||||
|
||||
elif not isinstance(config, DictConfig):
|
||||
@ -332,6 +340,7 @@ class ModelManager(object):
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
max_vram_cache_size = self.app_config.max_vram_cache_size,
|
||||
execution_device = device_type,
|
||||
precision = precision,
|
||||
sequential_offload = sequential_offload,
|
||||
@ -382,6 +391,16 @@ class ModelManager(object):
|
||||
def _get_model_cache_path(self, model_path):
|
||||
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def initialize_model_config(cls, config_path: Path):
|
||||
"""Create empty config file"""
|
||||
with open(config_path,'w') as yaml_file:
|
||||
yaml_file.write(yaml.dump({'__metadata__':
|
||||
{'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -404,7 +423,7 @@ class ModelManager(object):
|
||||
if model_key not in self.models:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
if model_key not in self.models:
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.models[model_key]
|
||||
model_path = self.app_config.root_path / model_config.path
|
||||
@ -416,14 +435,14 @@ class ModelManager(object):
|
||||
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = override_path
|
||||
model_path = self.app_config.root_path / override_path
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -431,6 +450,7 @@ class ModelManager(object):
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
dst_convert_path = self._get_model_cache_path(model_path)
|
||||
|
||||
model_path = model_class.convert_if_required(
|
||||
base_model=base_model,
|
||||
model_path=str(model_path), # TODO: refactor str/Path types logic
|
||||
@ -485,17 +505,32 @@ class ModelManager(object):
|
||||
"""
|
||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||
|
||||
def list_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
"""
|
||||
Returns a dict describing one installed model, using
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model,model_type,model_name)
|
||||
return models[0] if models else None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_name: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
@ -540,10 +575,7 @@ class ModelManager(object):
|
||||
model_cfg = self.models.pop(model_key, None)
|
||||
|
||||
if model_cfg is None:
|
||||
self.logger.error(
|
||||
f"Unknown model {model_key}"
|
||||
)
|
||||
return
|
||||
raise KeyError(f"Unknown model {model_key}")
|
||||
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
@ -570,13 +602,16 @@ class ModelManager(object):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
|
||||
The returned dict has the same format as the dict returned by
|
||||
model_info().
|
||||
"""
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -600,13 +635,74 @@ class ModelManager(object):
|
||||
old_model_cache.unlink()
|
||||
|
||||
# remove in-memory cache
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
# note: it not guaranteed to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models[model_key] = model_config
|
||||
self.commit()
|
||||
return AddModelResult(
|
||||
name = model_name,
|
||||
model_type = model_type,
|
||||
base_model = base_model,
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
'''
|
||||
info = self.model_info(model_name, base_model, model_type)
|
||||
if info["model_format"] != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||
# what submodeltype we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
|
||||
model = self.get_model(model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
**submodel,
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
model_attributes = info,
|
||||
clobber=True)
|
||||
except:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
rmtree(new_diffusers_path)
|
||||
raise
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
@ -688,6 +784,7 @@ class ModelManager(object):
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
self.models.pop(model_key, None)
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
else:
|
||||
@ -716,19 +813,19 @@ class ModelManager(object):
|
||||
|
||||
if model_path.is_relative_to(self.app_config.root_path):
|
||||
model_path = model_path.relative_to(self.app_config.root_path)
|
||||
try:
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
try:
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
imported_models = self.autoimport()
|
||||
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
def autoimport(self)->set[Path]:
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
'''
|
||||
@ -741,7 +838,6 @@ class ModelManager(object):
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
installed = set()
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
@ -755,13 +851,14 @@ class ModelManager(object):
|
||||
continue
|
||||
|
||||
self.logger.info(f'Scanning {autodir} for models to import')
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = set()
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
@ -770,16 +867,23 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
scanned_dirs.add(path)
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
try:
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
@ -789,7 +893,7 @@ class ModelManager(object):
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Set[str]:
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -802,20 +906,23 @@ class ModelManager(object):
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
|
||||
May return the following exceptions:
|
||||
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
||||
- ValueError - a corresponding model already exists
|
||||
'''
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
successfully_installed = set()
|
||||
successfully_installed = dict()
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
try:
|
||||
installed = installer.heuristic_install(thing)
|
||||
successfully_installed.update(installed)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
131
invokeai/backend/model_management/model_merge.py
Normal file
131
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
WeightedSum = "weighted_sum"
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
self.manager = manager
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_save (
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType,str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||
assert len(model_names) <= 2 or \
|
||||
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([config.root_path / info["path"]])
|
||||
|
||||
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||
logger.debug(f'interp = {interp}, merge_method={merge_method}')
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
attributes = dict(
|
||||
path = str(dump_path),
|
||||
description = f"Merge of models {', '.join(model_names)}",
|
||||
model_format = "diffusers",
|
||||
variant = ModelVariantType.Normal.value,
|
||||
vae = vae,
|
||||
)
|
||||
return self.manager.add_model(merged_model_name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
model_attributes = attributes,
|
||||
clobber = True
|
||||
)
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Union, Dict
|
||||
from typing import Callable, Literal, Union, Dict, Optional
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from .models import (
|
||||
@ -59,13 +59,13 @@ class ModelProbe(object):
|
||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise Exception("model parameter {model} is neither a Path, nor a model")
|
||||
raise ValueError("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
model_path: Path,
|
||||
model: Union[Dict, ModelMixin] = None,
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
|
||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo:
|
||||
'''
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
@ -78,7 +78,6 @@ class ModelProbe(object):
|
||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||
else:
|
||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||
|
||||
model_info = None
|
||||
try:
|
||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||
@ -105,7 +104,7 @@ class ModelProbe(object):
|
||||
) else 512,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@ -127,6 +126,8 @@ class ModelProbe(object):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -137,7 +138,7 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
@ -167,7 +168,7 @@ class ModelProbe(object):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError(f"Unable to determine model type for {folder_path}")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
@ -236,7 +237,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Cannot determine variant type")
|
||||
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -247,7 +248,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
raise Exception("Cannot determine base type")
|
||||
raise ValueError("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
@ -328,7 +329,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise Exception("Unable to determine base type for {self.checkpoint_path}")
|
||||
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
@ -417,7 +418,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
config_file = self.folder_path / 'config.json'
|
||||
if not config_file.exists():
|
||||
raise Exception(f"Cannot determine base type for {self.folder_path}")
|
||||
raise ValueError(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file,'r') as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
@ -434,7 +435,7 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise Exception('Unknown LoRA format encountered')
|
||||
raise ValueError('Unknown LoRA format encountered')
|
||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||
|
||||
############## register probe classes ######
|
||||
|
@ -2,7 +2,7 @@ import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
@ -68,7 +68,11 @@ def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
fields = inspect.get_annotations(model_config)
|
||||
|
||||
if hasattr(inspect,'get_annotations'):
|
||||
fields = inspect.get_annotations(model_config)
|
||||
else:
|
||||
fields = model_config.__annotations__
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
|
@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
@ -8,6 +8,7 @@ from .base import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
|
||||
checkpoint_path = self.model_path
|
||||
if os.path.isdir(checkpoint_path):
|
||||
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
file_path=checkpoint_path,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ import secrets
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
import einops
|
||||
import PIL.Image
|
||||
@ -17,12 +17,11 @@ import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
@ -46,7 +45,7 @@ from .diffusion import (
|
||||
InvokeAIDiffuserComponent,
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .offloading import FullyLoadedModelGroup, ModelGroup
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -105,7 +104,7 @@ class AddsMaskGuidance:
|
||||
_debug: Optional[Callable] = None
|
||||
|
||||
def __call__(
|
||||
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
|
||||
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
|
||||
) -> BaseOutput:
|
||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||
|
||||
@ -361,37 +360,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
if torch.backends.mps.is_available():
|
||||
# until pytorch #91617 is fixed, slicing is borked on MPS
|
||||
# https://github.com/pytorch/pytorch/issues/91617
|
||||
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
|
||||
pass
|
||||
if self.device.type == "cpu" or self.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||
else:
|
||||
if self.device.type == "cpu" or self.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = (
|
||||
latents.element_size() + 4
|
||||
)
|
||||
max_size_required_for_baddbmm = (
|
||||
16
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if max_size_required_for_baddbmm > (
|
||||
mem_free * 3.0 / 4.0
|
||||
): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = (
|
||||
latents.element_size() + 4
|
||||
)
|
||||
max_size_required_for_baddbmm = (
|
||||
16
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if max_size_required_for_baddbmm > (
|
||||
mem_free * 3.0 / 4.0
|
||||
): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
elif torch.backends.mps.is_available():
|
||||
# diffusers recommends always enabling for mps
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
# overridden method; types match the superclass.
|
||||
@ -917,20 +913,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
with torch.inference_mode():
|
||||
if device.type == "mps":
|
||||
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
|
||||
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
|
||||
self.vae.to(CPU_DEVICE)
|
||||
init_image = init_image.to(CPU_DEVICE)
|
||||
else:
|
||||
self._model_group.load(self.vae)
|
||||
self._model_group.load(self.vae)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample().to(
|
||||
dtype=dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
if device.type == "mps":
|
||||
self.vae.to(device)
|
||||
init_latents = init_latents.to(device)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
return init_latents
|
||||
|
@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
if conditioned_next_x.device.type == "mps":
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_standard_conditioning_sequentially(
|
||||
@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
|
||||
# low-memory sequential path
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
|
||||
if conditioned_next_x.device.type == "mps":
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
# TODO: looks unused
|
||||
|
@ -4,7 +4,7 @@ import warnings
|
||||
import weakref
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
from accelerate.utils import send_to_device
|
||||
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
"""
|
||||
|
||||
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
|
||||
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
|
||||
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
|
||||
|
||||
def __init__(self, execution_device: torch.device):
|
||||
super().__init__(execution_device)
|
||||
|
@ -4,6 +4,7 @@ from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
from typing import Union
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
@ -28,6 +29,8 @@ def choose_precision(device: torch.device) -> str:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
return "float16"
|
||||
elif device.type == "mps":
|
||||
return "float16"
|
||||
return "float32"
|
||||
|
||||
|
||||
@ -49,7 +52,7 @@ def choose_autocast(precision):
|
||||
return nullcontext
|
||||
|
||||
|
||||
def normalize_device(device: str | torch.device) -> torch.device:
|
||||
def normalize_device(device: Union[str, torch.device]) -> torch.device:
|
||||
"""Ensure device has a device index defined, if appropriate."""
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
|
63
invokeai/backend/util/mps_fixes.py
Normal file
63
invokeai/backend/util/mps_fixes.py
Normal file
@ -0,0 +1,63 @@
|
||||
import torch
|
||||
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
|
||||
|
||||
_torch_layer_norm = torch.nn.functional.layer_norm
|
||||
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
if bias is not None:
|
||||
bias = bias.float()
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
|
||||
else:
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
|
||||
|
||||
torch.nn.functional.layer_norm = new_layer_norm
|
||||
|
||||
|
||||
_torch_tensor_permute = torch.Tensor.permute
|
||||
def new_torch_tensor_permute(input, *dims):
|
||||
result = _torch_tensor_permute(input, *dims)
|
||||
if input.device == "mps" and input.dtype == torch.float16:
|
||||
result = result.contiguous()
|
||||
return result
|
||||
|
||||
torch.Tensor.permute = new_torch_tensor_permute
|
||||
|
||||
|
||||
_torch_lerp = torch.lerp
|
||||
def new_torch_lerp(input, end, weight, *, out=None):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
end = end.float()
|
||||
if isinstance(weight, torch.Tensor):
|
||||
weight = weight.float()
|
||||
if out is not None:
|
||||
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
|
||||
else:
|
||||
out_fp32 = None
|
||||
result = _torch_lerp(input, end, weight, out=out_fp32)
|
||||
if out is not None:
|
||||
out.copy_(out_fp32.half())
|
||||
del out_fp32
|
||||
return result.half()
|
||||
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
torch.lerp = new_torch_lerp
|
||||
|
||||
|
||||
_torch_interpolate = torch.nn.functional.interpolate
|
||||
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
|
||||
else:
|
||||
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||
|
||||
torch.nn.functional.interpolate = new_torch_interpolate
|
Reference in New Issue
Block a user