mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix - apply precision to text_encoder
This commit is contained in:
parent
039fa73269
commit
1f602e6143
@ -112,7 +112,7 @@ SIZE_GUESSTIMATE = {
|
|||||||
|
|
||||||
# The list of model classes we know how to fetch, for typechecking
|
# The list of model classes we know how to fetch, for typechecking
|
||||||
ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
|
ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
|
||||||
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel)
|
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel, CLIPTextModel)
|
||||||
|
|
||||||
class UnsafeModelException(Exception):
|
class UnsafeModelException(Exception):
|
||||||
"Raised when a legacy model file fails the picklescan test"
|
"Raised when a legacy model file fails the picklescan test"
|
||||||
@ -320,7 +320,7 @@ class ModelCache(object):
|
|||||||
if model.device != cache.execution_device:
|
if model.device != cache.execution_device:
|
||||||
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
|
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
model.to(cache.execution_device, dtype=cache.precision) # move into GPU
|
model.to(cache.execution_device) # move into GPU
|
||||||
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
||||||
cache.model_sizes[key] = mem.vram_used # more accurate size
|
cache.model_sizes[key] = mem.vram_used # more accurate size
|
||||||
|
|
||||||
@ -534,8 +534,11 @@ class ModelCache(object):
|
|||||||
|
|
||||||
extra_args = dict()
|
extra_args = dict()
|
||||||
if model_class in DiffusionClasses:
|
if model_class in DiffusionClasses:
|
||||||
extra_args = dict(
|
extra_args.update(
|
||||||
torch_dtype=self.precision,
|
torch_dtype=self.precision,
|
||||||
|
)
|
||||||
|
if model_class == StableDiffusionGeneratorPipeline:
|
||||||
|
extra_args.update(
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user