Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'main' into lstein-improve-ti-frontend
@@ -146,7 +146,7 @@ class Generate:
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
free_gpu_mem: bool=False,
safety_checker:bool=False,
max_loaded_models:int=2,
# these are deprecated; if present they override values in the conf file
@@ -445,7 +445,11 @@ class Generate:
self._set_sampler()

# apply the concepts library to the prompt
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
prompt,
lambda concepts: self.load_huggingface_concepts(concepts),
self.model.textual_inversion_manager.get_all_trigger_strings()
)

# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
@@ -460,10 +464,13 @@ class Generate:
init_image = None
mask_image = None

if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
try:
if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
except AttributeError:
print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
pass

try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
@@ -531,6 +538,7 @@ class Generate:
inpaint_height = inpaint_height,
inpaint_width = inpaint_width,
enable_image_debugging = enable_image_debugging,
free_gpu_mem=self.free_gpu_mem,
)

if init_color:
@@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):

if model_path.startswith(('http:','https:','ftp:')):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
@@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
model_description=default_description
)
config_file = None

default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
completer.complete_extensions(('.yaml','.yml'))
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
completer.set_line(str(default))
done = False
while not done:
config_file = input('Configuration file for this model: ').strip()
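For reference, the dispatch above distinguishes remote URLs, local checkpoint files, and Hugging Face repo IDs. A minimal sketch of that selection logic, not part of the diff; the example repo name is a placeholder:

import os, re

def classify_model_path(model_path: str) -> str:
    # URLs and local .ckpt/.safetensors files go through the checkpoint importer;
    # anything shaped like "owner/repo" is treated as a diffusers repo id.
    if model_path.startswith(('http:', 'https:', 'ftp:')):
        return 'checkpoint (download)'
    if os.path.isfile(model_path) and model_path.endswith(('.ckpt', '.safetensors')):
        return 'checkpoint (local file)'
    if re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
        return 'diffusers repo'
    return 'unknown'

# e.g. classify_model_path('runwayml/stable-diffusion-v1-5') -> 'diffusers repo'
# (a path such as 'models/custom.safetensors' returns 'checkpoint (local file)' when the file exists)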
@@ -56,9 +56,11 @@ class CkptGenerator():
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
attention_maps_callback = None,
free_gpu_mem: bool=False,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
@@ -21,7 +21,7 @@ import os
import re
import torch
from pathlib import Path
from ldm.invoke.globals import Globals
from ldm.invoke.globals import Globals, global_cache_dir
from safetensors.torch import load_file

try:
@@ -637,7 +637,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):


def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub'))

keys = list(checkpoint.keys())

@@ -677,7 +677,8 @@ textenc_pattern = re.compile("|".join(protected.keys()))


def convert_paint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
cache_dir = global_cache_dir('hub')
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
model = PaintByExampleImageEncoder(config)

keys = list(checkpoint.keys())
@@ -744,7 +745,8 @@ def convert_paint_by_example_checkpoint(checkpoint):


def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
cache_dir=global_cache_dir('hub')
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir)

keys = list(checkpoint.keys())

@@ -795,6 +797,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
):

checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
cache_dir = global_cache_dir('hub')

# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
@@ -904,7 +907,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,

if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer",cache_dir=global_cache_dir('diffusers'))
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@@ -917,8 +920,8 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
)
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
@@ -929,9 +932,9 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
)
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@@ -944,7 +947,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

pipe.save_pretrained(
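The conversion entry point now chooses a loader by file suffix. A minimal standalone sketch of that choice (map_location is added here for illustration and is not part of the diff):

import torch
from pathlib import Path
from safetensors.torch import load_file

def load_state_dict(checkpoint_path: str) -> dict:
    # .safetensors files are read with safetensors; anything else falls back to torch.load
    if Path(checkpoint_path).suffix == '.safetensors':
        return load_file(checkpoint_path)
    return torch.load(checkpoint_path, map_location='cpu')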
@@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
'''
if not concept_name in self.list_concepts():
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
return None
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')

@@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)

def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
def replace_concepts_with_triggers(self,
prompt:str,
load_concepts_callback: Callable[[list], any],
excluded_tokens:list[str])->str:
'''
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.

If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.

`excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding.
'''
concepts = self.match_concept.findall(prompt)
if not concepts:
@@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
load_concepts_callback(concepts)

def do_replace(match)->str:
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
return f'<{match.group(1)}>'
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)
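A sketch of the intended behaviour of the new excluded_tokens argument, based on the docstring above; the concept names, the `library` variable, and the callback are placeholders:

# Triggers belonging to locally-loaded embeddings are passed in as excluded_tokens
# (in generate.py above they come from get_all_trigger_strings()) and survive unchanged;
# any remaining <concept_name> tag is resolved through the Hugging Face concepts library.
prompt = 'a portrait in the style of <my-local-style> and <remote-concept>'
excluded = ['<my-local-style>']
result = library.replace_concepts_with_triggers(
    prompt,
    lambda concepts: library.download_concepts(concepts),  # placeholder callback
    excluded,
)
# '<my-local-style>' is left as-is; '<remote-concept>' becomes its registered trigger string.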
@@ -62,9 +62,11 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
free_gpu_mem: bool=False,
**kwargs):
scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
@@ -29,6 +29,7 @@ else:

# Where to look for the initialization file
Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
@@ -49,6 +50,9 @@ Globals.disable_xformers = False
# whether we are forcing full precision
Globals.full_precision = False

def global_config_file()->Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)

def global_config_dir()->Path:
return Path(Globals.root, Globals.config_dir)
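For illustration, these helpers resolve everything relative to Globals.root; a small sketch with a placeholder root directory:

from pathlib import Path
from ldm.invoke.globals import Globals, global_config_file, global_config_dir

Globals.root = '/home/user/invokeai'   # placeholder root for illustration
assert global_config_dir()  == Path('/home/user/invokeai/configs')
assert global_config_file() == Path('/home/user/invokeai/configs/models.yaml')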
ldm/invoke/merge_diffusers.py (new file, 62 lines)
@@ -0,0 +1,62 @@
'''
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
'''
import os
from typing import List
from diffusers import DiffusionPipeline
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf

def merge_diffusion_models(models:List['str'],
merged_model_name:str,
alpha:float=0.5,
interp:str=None,
force:bool=False,
**kwargs):
'''
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
'''
config_file = global_config_file()
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]

pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
custom_pipeline='checkpoint_merger')
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs)
dump_path = global_models_dir() / 'merged_diffusers'
os.makedirs(dump_path,exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained (
dump_path,
safe_serialization=1
)
model_manager.import_diffuser_model(
dump_path,
model_name = merged_model_name,
description = f'Merge of models {", ".join(models)}'
)
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
if vae := model_manager.config[models[0]].get('vae',None):
print(f'>> Using configured VAE assigned to {models[0]}')
model_manager.config[merged_model_name]['vae'] = vae

model_manager.commit(config_file)
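A possible invocation of the new helper, assuming two diffusers-format models are already registered in models.yaml; the model names here are placeholders:

from ldm.invoke.merge_diffusers import merge_diffusion_models

# Plain weighted-sum merge (interp=None) giving equal weight to both checkpoints.
merge_diffusion_models(
    models=['stable-diffusion-1.5', 'my-finetune'],   # names as they appear in models.yaml
    merged_model_name='sd15-plus-my-finetune',
    alpha=0.5,
    interp=None,
)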
@@ -37,7 +37,11 @@ from ldm.util import instantiate_from_config, ask_user
DEFAULT_MAX_MODELS=2

class ModelManager(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
def __init__(self,
config:OmegaConf,
device_type:str='cpu',
precision:str='float16',
max_loaded_models=DEFAULT_MAX_MODELS):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
@@ -143,7 +147,7 @@ class ModelManager(object):
Return true if this is a legacy (.ckpt) model
'''
info = self.model_info(model_name)
if 'weights' in info and info['weights'].endswith('.ckpt'):
if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
return True
return False

@@ -362,8 +366,14 @@ class ModelManager(object):
vae = os.path.normpath(os.path.join(Globals.root,vae))
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
vae_ckpt = None
vae_dict = None
if vae.endswith('.safetensors'):
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
else:
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')
@@ -536,7 +546,7 @@ class ModelManager(object):
format='diffusers',
)
if isinstance(repo_or_path,Path) and repo_or_path.exists():
new_config.update(path=repo_or_path)
new_config.update(path=str(repo_or_path))
else:
new_config.update(repo_id=repo_or_path)
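The VAE branch above mirrors the checkpoint handling: read a safetensors file or a pickled checkpoint, then drop the loss/discriminator weights before loading into the first-stage model. A compact standalone sketch of that filtering (imports assumed, not part of the diff):

import torch
import safetensors.torch

def read_vae_state_dict(vae_path: str) -> dict:
    if vae_path.endswith('.safetensors'):
        vae_ckpt = safetensors.torch.load_file(vae_path)      # already a flat state dict
        items = vae_ckpt.items()
    else:
        vae_ckpt = torch.load(vae_path, map_location='cpu')   # wrapped in a 'state_dict' key
        items = vae_ckpt['state_dict'].items()
    # keys beginning with "loss" belong to the training loss and are skipped at inference time
    return {k: v for k, v in items if k[0:4] != 'loss'}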
@@ -1,18 +1,16 @@
import math
import os.path
from functools import partial
from typing import Optional

import clip
import kornia
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from einops import repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.invoke.devices import choose_torch_device
from ldm.invoke.globals import Globals, global_cache_dir
#from ldm.modules.textual_inversion_manager import TextualInversionManager

from ldm.invoke.devices import choose_torch_device
from ldm.invoke.globals import global_cache_dir
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
@@ -654,21 +652,22 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
per_token_weights += [weight] * len(this_fragment_token_ids)

# leave room for bos/eos
if len(all_token_ids) > self.max_length - 2:
excess_token_count = len(all_token_ids) - self.max_length - 2
max_token_count_without_bos_eos_markers = self.max_length - 2
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
# TODO build nice description string of how the truncation was applied
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_token_ids = all_token_ids[0:self.max_length]
per_token_weights = per_token_weights[0:self.max_length]
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]

# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
# (77 = self.max_length)
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
per_token_weights = [1.0] + per_token_weights + [1.0]
pad_length = self.max_length - len(all_token_ids)
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
all_token_ids += [self.tokenizer.pad_token_id] * pad_length
per_token_weights += [1.0] * pad_length

all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
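The padding change is easiest to see on the final id array. A rough sketch of the layout before and after, with `tokenizer` and `prompt_ids` as placeholders (max_length is 77, as noted in the comments above):

max_length = 77                                   # self.max_length in the class above
prompt_ids = prompt_ids[:max_length - 2]          # leave room for bos/eos (placeholder list of ints)
ids = [tokenizer.bos_token_id] + prompt_ids + [tokenizer.eos_token_id]
pad_length = max_length - len(ids)
ids += [tokenizer.pad_token_id] * pad_length      # previously eos_token_id was used as padding
# result: [bos, <prompt tokens>, eos, pad, pad, ...] instead of [bos, <prompt tokens>, eos, eos, ...]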
@@ -3,8 +3,9 @@ import math
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from ldm.modules.textual_inversion_manager import TextualInversionManager
from ldm.invoke.devices import torch_dtype
from ldm.modules.textual_inversion_manager import TextualInversionManager


class WeightedPromptFragmentsToEmbeddingsConverter():

@@ -22,8 +23,8 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
return self.tokenizer.model_max_length

def get_embeddings_for_weighted_prompt_fragments(self,
text: list[str],
fragment_weights: list[float],
text: list[list[str]],
fragment_weights: list[list[float]],
should_return_tokens: bool = False,
device='cpu'
) -> torch.Tensor:
@@ -198,12 +199,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]

# pad out to a self.max_length-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
# pad out to a self.max_length-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
# (typically self.max_length == 77)
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
per_token_weights = [1.0] + per_token_weights + [1.0]
pad_length = self.max_length - len(all_token_ids)
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
all_token_ids += [self.tokenizer.pad_token_id] * pad_length
per_token_weights += [1.0] * pad_length

all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
@@ -38,11 +38,15 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger):
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
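The widened check above treats a concept as locally available if any of three spellings of its trigger is already registered; roughly, with `manager` standing in for the TextualInversionManager instance:

candidates = (trigger, concept_name, f'<{concept_name}>')   # plain trigger, bare name, literal angle-bracket form
if any(manager.has_textual_inversion_for_trigger_string(t) for t in candidates):
    # a local embedding already owns this trigger, so the concepts-library download is skipped
    pass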