Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
conversion script uses invokeai models cache by default
This commit is contained in:
parent 89791d91e8
commit ffcc5ad795
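This change threads the InvokeAI models cache into the Hugging Face from_pretrained() calls in the checkpoint conversion script, so tokenizers, CLIP encoders, the safety checker, and feature extractors are downloaded into the InvokeAI cache rather than the default ~/.cache/huggingface location. A minimal sketch of the pattern, assuming global_cache_dir('hub') resolves to a directory under the InvokeAI runtime root (the helper below is a hypothetical stand-in, not the actual implementation in ldm.invoke.globals):

    # Hypothetical stand-in for ldm.invoke.globals.global_cache_dir, shown only
    # to illustrate the caching pattern this commit applies.
    from pathlib import Path
    from transformers import CLIPTokenizer

    def global_cache_dir(subdir: str = "hub") -> Path:
        # Assumption: the real helper derives this path from the InvokeAI root (Globals).
        return Path.home() / "invokeai" / "models" / subdir

    # Passing cache_dir redirects downloads away from ~/.cache/huggingface
    # and into the InvokeAI-managed models directory.
    tokenizer = CLIPTokenizer.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=global_cache_dir("hub"),
    )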
@@ -21,7 +21,7 @@ import os
 import re
 import torch
 from pathlib import Path
-from ldm.invoke.globals import Globals
+from ldm.invoke.globals import Globals, global_cache_dir
 from safetensors.torch import load_file
 
 try:
@@ -637,7 +637,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
 
 
 def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub'))
 
     keys = list(checkpoint.keys())
 
@@ -677,7 +677,8 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 
 
 def convert_paint_by_example_checkpoint(checkpoint):
-    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    cache_dir = global_cache_dir('hub')
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
     model = PaintByExampleImageEncoder(config)
 
     keys = list(checkpoint.keys())
@@ -744,7 +745,8 @@ def convert_paint_by_example_checkpoint(checkpoint):
 
 
 def convert_open_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+    cache_dir=global_cache_dir('hub')
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir)
 
     keys = list(checkpoint.keys())
 
@@ -795,6 +797,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
 ):
 
     checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
+    cache_dir = global_cache_dir('hub')
 
     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
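Hoisting cache_dir = global_cache_dir('hub') to the top of convert_ckpt_to_diffuser lets the model-type branches below reuse one resolved cache location instead of recomputing it at every from_pretrained call.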
@@ -904,7 +907,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
 
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer",cache_dir=global_cache_dir('diffusers'))
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
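Note that this branch resolves global_cache_dir('diffusers') for the stabilityai/stable-diffusion-2 tokenizer, whereas the remaining branches reuse the cache_dir resolved from the 'hub' cache above.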
@@ -917,8 +920,8 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = PaintByExamplePipeline(
             vae=vae,
             image_encoder=vision_model,
@@ -929,9 +932,9 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
         )
     elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
         text_model = convert_ldm_clip_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = StableDiffusionPipeline(
             vae=vae,
             text_encoder=text_model,
@@ -944,7 +947,7 @@ def convert_ckpt_to_diffuser(checkpoint_path:str,
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
 
         pipe.save_pretrained(