Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: Merge branch 'main' into lstein-improve-ti-frontend
@@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):
     if model_path.startswith(('http:','https:','ftp:')):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
-    elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
+    elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
     elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
         model_name = import_diffuser_model(model_path, gen, opt, completer)
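As a rough illustration of the dispatch above (a sketch only; the real function also prompts for a configuration file and registers the model with the model manager, and the repo id below is just an example):

import os
import re

def classify_model_path(model_path: str) -> str:
    # Mirrors the branch order in the hunk above: URLs and local .ckpt/.safetensors
    # files go through the checkpoint importer, while "owner/name" strings are
    # treated as Hugging Face diffusers repo ids.
    if model_path.startswith(('http:', 'https:', 'ftp:')):
        return 'import_ckpt_model (download from URL)'
    elif os.path.isfile(model_path) and model_path.endswith(('.ckpt', '.safetensors')):
        return 'import_ckpt_model (local checkpoint or safetensors file)'
    elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
        return 'import_diffuser_model (Hugging Face repo id)'
    return 'unrecognized input'

print(classify_model_path('runwayml/stable-diffusion-v1-5'))   # example repo id only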
@@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
         model_description=default_description
     )
     config_file = None
+    default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
     completer.complete_extensions(('.yaml','.yml'))
-    completer.set_line('configs/stable-diffusion/v1-inference.yaml')
+    completer.set_line(str(default))
     done = False
     while not done:
         config_file = input('Configuration file for this model: ').strip()
@@ -56,9 +56,11 @@ class CkptGenerator():
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  attention_maps_callback = None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
@@ -21,7 +21,7 @@ import os
 import re
 import torch
 from pathlib import Path
-from ldm.invoke.globals import Globals
+from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file

 try:
@@ -637,7 +637,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):


 def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub'))

     keys = list(checkpoint.keys())
@@ -677,7 +677,8 @@ textenc_pattern = re.compile("|".join(protected.keys()))


 def convert_paint_by_example_checkpoint(checkpoint):
-    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    cache_dir = global_cache_dir('hub')
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
     model = PaintByExampleImageEncoder(config)

     keys = list(checkpoint.keys())
@@ -744,7 +745,8 @@ def convert_paint_by_example_checkpoint(checkpoint):


 def convert_open_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+    cache_dir=global_cache_dir('hub')
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir)

     keys = list(checkpoint.keys())
@@ -795,6 +797,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
                     ):

     checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    cache_dir = global_cache_dir('hub')

     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
@@ -904,7 +907,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,

     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer",cache_dir=global_cache_dir('diffusers'))
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
@@ -917,8 +920,8 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = PaintByExamplePipeline(
             vae=vae,
             image_encoder=vision_model,
@@ -929,9 +932,9 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
         text_model = convert_ldm_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
@@ -944,7 +947,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    pipe.save_pretrained(
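The hunks above all apply the same fix: every `from_pretrained()` call used during checkpoint conversion is pointed at InvokeAI's managed cache via `cache_dir`, so support files are no longer pulled into the default `~/.cache/huggingface` location. A minimal sketch of the pattern, assuming transformers and an InvokeAI install:

from transformers import CLIPTokenizer
from ldm.invoke.globals import global_cache_dir

# The tokenizer is downloaded into (or reused from) the 'hub' cache under the
# InvokeAI runtime root instead of the default Hugging Face cache directory.
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=global_cache_dir('hub'),
)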
@@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
         be downloaded.
         '''
         if not concept_name in self.list_concepts():
-            print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
+            print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
             return None
         return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
@@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
             return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
         return self.match_trigger.sub(do_replace, prompt)

-    def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
+    def replace_concepts_with_triggers(self,
+                                       prompt:str,
+                                       load_concepts_callback: Callable[[list], any],
+                                       excluded_tokens:list[str])->str:
         '''
         Given a prompt string that contains `<concept_name>` tags, replace
         these tags with the appropriate trigger.

         If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
         of `concepts_name` strings.

+        `excluded_tokens` are any tokens that should not be replaced, typically because they
+        are trigger tokens from a locally-loaded embedding.
         '''
         concepts = self.match_concept.findall(prompt)
         if not concepts:
@@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
         load_concepts_callback(concepts)

         def do_replace(match)->str:
+            if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
+                return f'<{match.group(1)}>'
             return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
         return self.match_concept.sub(do_replace, prompt)
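A sketch of the new call contract, with hypothetical trigger names (the concepts-library instance and loader callback are assumed to already exist in the caller; this is not code from the commit):

prompt = "a portrait in the style of <my-local-style> and <midjourney-style>"
new_prompt = concepts_library.replace_concepts_with_triggers(
    prompt,
    load_concepts_callback=lambda concepts: None,   # no-op loader, for illustration only
    excluded_tokens=['<my-local-style>'],           # local embedding trigger: left untouched
)
# Only <midjourney-style> is looked up in the Hugging Face concepts library;
# <my-local-style> passes through unchanged because it is excluded.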
@@ -62,9 +62,11 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = nullcontext
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
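Both generator base classes now accept and store the flag, so callers can opt in per invocation. A hypothetical call, assuming an already-constructed generator and sampler (not code from this commit):

results = generator.generate(
    "a castle on a hill at sunset",
    init_image=None,
    width=512,
    height=512,
    sampler=sampler,          # assumed to exist in the calling code
    safety_checker=None,
    free_gpu_mem=True,        # new keyword: stored as self.free_gpu_mem for downstream use
)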
@@ -29,6 +29,7 @@ else:

 # Where to look for the initialization file
 Globals.initfile = 'invokeai.init'
+Globals.models_file = 'models.yaml'
 Globals.models_dir = 'models'
 Globals.config_dir = 'configs'
 Globals.autoscan_dir = 'weights'
@@ -49,6 +50,9 @@ Globals.disable_xformers = False
 # whether we are forcing full precision
 Globals.full_precision = False

+def global_config_file()->Path:
+    return Path(Globals.root, Globals.config_dir, Globals.models_file)
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)
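With `Globals.models_file` and the new helper in place, other modules (including the new merge_diffusers module below) can locate models.yaml without hard-coding its path. A minimal usage sketch:

from omegaconf import OmegaConf
from ldm.invoke.globals import global_config_file

# Resolves to <Globals.root>/configs/models.yaml with the defaults above.
models_config = OmegaConf.load(global_config_file())
print(list(models_config.keys()))   # names of the registered models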
ldm/invoke/merge_diffusers.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+'''
+ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
+used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
+'''
+import os
+from typing import List
+from diffusers import DiffusionPipeline
+from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
+from ldm.invoke.model_manager import ModelManager
+from omegaconf import OmegaConf
+
+def merge_diffusion_models(models:List['str'],
+                           merged_model_name:str,
+                           alpha:float=0.5,
+                           interp:str=None,
+                           force:bool=False,
+                           **kwargs):
+    '''
+    models - up to three models, designated by their InvokeAI models.yaml model name
+    merged_model_name = name for new model
+    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
+            would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
+    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
+             Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
+    force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
+
+    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
+         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
+    '''
+    config_file = global_config_file()
+    model_manager = ModelManager(OmegaConf.load(config_file))
+    for mod in models:
+        assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
+        assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
+    model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
+
+    pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
+                                             cache_dir=kwargs.get('cache_dir',global_cache_dir()),
+                                             custom_pipeline='checkpoint_merger')
+    merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
+                             alpha=alpha,
+                             interp=interp,
+                             force=force,
+                             **kwargs)
+    dump_path = global_models_dir() / 'merged_diffusers'
+    os.makedirs(dump_path,exist_ok=True)
+    dump_path = dump_path / merged_model_name
+    merged_pipe.save_pretrained (
+        dump_path,
+        safe_serialization=1
+    )
+    model_manager.import_diffuser_model(
+        dump_path,
+        model_name = merged_model_name,
+        description = f'Merge of models {", ".join(models)}'
+    )
+    print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
+    if vae := model_manager.config[models[0]].get('vae',None):
+        print(f'>> Using configured VAE assigned to {models[0]}')
+        model_manager.config[merged_model_name]['vae'] = vae
+
+    model_manager.commit(config_file)
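Per the docstring above, a hedged usage sketch (the model names are hypothetical and must already be registered in models.yaml as diffusers-format models):

from ldm.invoke.merge_diffusers import merge_diffusion_models

merge_diffusion_models(
    models=['stable-diffusion-1.5', 'analog-diffusion-1.0'],  # hypothetical registered names
    merged_model_name='sd15-analog-blend',
    alpha=0.5,        # 0.5 weights the checkpoints equally; per the docstring, higher values reduce the first model's influence
    interp=None,      # None = weighted-sum; 'add_difference' is required when merging three models
    force=False,
)
# The merged pipeline is written under the models directory in 'merged_diffusers/<merged_model_name>'
# and registered in models.yaml by the function itself.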
@@ -37,7 +37,11 @@ from ldm.util import instantiate_from_config, ask_user
 DEFAULT_MAX_MODELS=2

 class ModelManager(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
+    def __init__(self,
+                 config:OmegaConf,
+                 device_type:str='cpu',
+                 precision:str='float16',
+                 max_loaded_models=DEFAULT_MAX_MODELS):
         '''
         Initialize with the path to the models.yaml config file,
         the torch device type, and precision. The optional
@@ -143,7 +147,7 @@ class ModelManager(object):
         Return true if this is a legacy (.ckpt) model
         '''
         info = self.model_info(model_name)
-        if 'weights' in info and info['weights'].endswith('.ckpt'):
+        if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
             return True
         return False
@@ -362,8 +366,14 @@ class ModelManager(object):
             vae = os.path.normpath(os.path.join(Globals.root,vae))
             if os.path.exists(vae):
                 print(f' | Loading VAE weights from: {vae}')
-                vae_ckpt = torch.load(vae, map_location="cpu")
-                vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+                vae_ckpt = None
+                vae_dict = None
+                if vae.endswith('.safetensors'):
+                    vae_ckpt = safetensors.torch.load_file(vae)
+                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
+                else:
+                    vae_ckpt = torch.load(vae, map_location="cpu")
+                    vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
                 model.first_stage_model.load_state_dict(vae_dict, strict=False)
             else:
                 print(f' | VAE file {vae} not found. Skipping.')
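The branch above exists because the two formats store weights differently: safetensors.torch.load_file() returns a flat tensor dict, while a legacy .ckpt read with torch.load() nests the tensors under a 'state_dict' key. A standalone restatement of that logic, assuming only torch and safetensors are installed:

import torch
import safetensors.torch

def load_vae_state_dict(vae_path: str) -> dict:
    # Drop the loss-related keys either way, mirroring the diff above.
    if vae_path.endswith('.safetensors'):
        vae_ckpt = safetensors.torch.load_file(vae_path)          # flat {name: tensor} dict
        return {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
    vae_ckpt = torch.load(vae_path, map_location="cpu")           # checkpoint wrapping 'state_dict'
    return {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}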
@@ -536,7 +546,7 @@ class ModelManager(object):
             format='diffusers',
         )
         if isinstance(repo_or_path,Path) and repo_or_path.exists():
-            new_config.update(path=repo_or_path)
+            new_config.update(path=str(repo_or_path))
         else:
             new_config.update(repo_id=repo_or_path)