load embeddings after a ckpt legacy model is converted to diffusers

- Fixes #2954
- Also improves diagnostic reporting during embedding loading.
This commit is contained in:
Lincoln Stein 2023-03-23 15:21:58 -04:00
parent 485f6e5954
commit f751dcd245
2 changed files with 4 additions and 6 deletions

View File

@ -362,6 +362,7 @@ class ModelManager(object):
raise NotImplementedError( raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}" f"Unknown model format {model_name}: {model_format}"
) )
self._add_embeddings_to_model(model)
# usage statistics # usage statistics
toc = time.time() toc = time.time()
@ -436,7 +437,6 @@ class ModelManager(object):
height = width height = width
print(f" | Default image dimensions = {width} x {height}") print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash return pipeline, width, height, model_hash

View File

@ -6,7 +6,6 @@ The interface is through the Concepts() object.
""" """
import os import os
import re import re
import traceback
from typing import Callable from typing import Callable
from urllib import error as ul_error from urllib import error as ul_error
from urllib import request from urllib import request
@ -15,7 +14,6 @@ from huggingface_hub import (
HfApi, HfApi,
HfFolder, HfFolder,
ModelFilter, ModelFilter,
ModelSearchArguments,
hf_hub_url, hf_hub_url,
) )
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
""" """
if not concept_name in self.list_concepts(): if not concept_name in self.list_concepts():
print( print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept." f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
) )
return None return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin") return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
except ul_error.HTTPError as e: except ul_error.HTTPError as e:
if e.code == 404: if e.code == 404:
print( print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept." f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
) )
else: else:
print( print(
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
return False return False
except ul_error.URLError as e: except ul_error.URLError as e:
print( print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept." f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
) )
os.rmdir(dest) os.rmdir(dest)
return False return False