Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
model_cache: let offload_model work with DiffusionPipeline, sorta.
commit 01ff1cff58
parent 95db6e80ee
@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+import warnings
 from pathlib import Path
 
 import torch
@@ -413,10 +414,13 @@ class ModelCache(object):
 
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
-            model.cond_stage_model.device = 'cpu'
-            model.first_stage_model.to('cpu')
-            model.cond_stage_model.to('cpu')
-            model.model.to('cpu')
+            try:
+                model.cond_stage_model.device = 'cpu'
+                model.first_stage_model.to('cpu')
+                model.cond_stage_model.to('cpu')
+                model.model.to('cpu')
+            except AttributeError as e:
+                warnings.warn(f"TODO: clean up legacy model-management: {e}")
             return model.to('cpu')
         else:
             return model
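The try/except above is what makes offload_model work with DiffusionPipeline "sorta": legacy ldm checkpoints expose cond_stage_model, first_stage_model, and model sub-modules, while a diffusers DiffusionPipeline does not, so the AttributeError branch lets pipeline objects skip the legacy bookkeeping and still land on the CPU via the final .to('cpu'). A minimal, self-contained sketch of that control flow; LegacyModel and PipelineLike are hypothetical stand-ins, not InvokeAI or diffusers classes:

import warnings
import torch

class LegacyModel(torch.nn.Module):
    """Stand-in for an ldm checkpoint: has the sub-modules the cache pokes at."""
    def __init__(self):
        super().__init__()
        self.cond_stage_model = torch.nn.Linear(4, 4)
        self.first_stage_model = torch.nn.Linear(4, 4)
        self.model = torch.nn.Linear(4, 4)

class PipelineLike(torch.nn.Module):
    """Stand-in for a DiffusionPipeline: none of the legacy attributes exist."""

def model_to_cpu(model, cache_device='cuda'):
    if cache_device != 'cpu':
        try:
            # Legacy-only bookkeeping; raises AttributeError on pipelines.
            model.cond_stage_model.device = 'cpu'
            model.first_stage_model.to('cpu')
            model.cond_stage_model.to('cpu')
            model.model.to('cpu')
        except AttributeError as e:
            warnings.warn(f'TODO: clean up legacy model-management: {e}')
        return model.to('cpu')   # both kinds of model still end up here
    return model

model_to_cpu(LegacyModel())   # silent: legacy path succeeds
model_to_cpu(PipelineLike())  # warns, but the offload itself still works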
@@ -1,5 +1,7 @@
 import math
 import os.path
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -235,13 +237,15 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    tokenizer: CLIPTokenizer
+    transformer: CLIPTextModel
 
     def __init__(
         self,
-        version='openai/clip-vit-large-patch14',
-        max_length=77,
-        tokenizer=None,
-        transformer=None,
+        version:str='openai/clip-vit-large-patch14',
+        max_length:int=77,
+        tokenizer:Optional[CLIPTokenizer]=None,
+        transformer:Optional[CLIPTextModel]=None,
     ):
         super().__init__()
         cache = os.path.join(Globals.root,'models',version)
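The new tokenizer and transformer parameters mean FrozenCLIPEmbedder no longer has to load its own copy of the text encoder: it can be handed the CLIPTokenizer and CLIPTextModel that a DiffusionPipeline already carries. A sketch of that wiring, assuming the embedder's import path in this repo; the hand-off itself is illustrative, not code from this commit:

from transformers import CLIPTokenizer, CLIPTextModel
from ldm.modules.encoders.modules import FrozenCLIPEmbedder  # path assumed

version = 'openai/clip-vit-large-patch14'

# Components a StableDiffusionPipeline would already have loaded
# (pipe.tokenizer and pipe.text_encoder in diffusers terms).
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)

# Before this commit only the first form existed; the second reuses
# the already-loaded components instead of downloading a second copy.
owned  = FrozenCLIPEmbedder(version=version)
shared = FrozenCLIPEmbedder(tokenizer=tokenizer, transformer=transformer)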
@@ -464,6 +468,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def device(self):
         return self.transformer.device
 
+    @device.setter
+    def device(self, device):
+        self.transformer.to(device=device)
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
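Pairing the existing device getter with a setter turns embedder.device into a read/write property: reads still come from the transformer, and writes now forward to transformer.to() instead of raising AttributeError (a property without a setter is read-only). The pattern in isolation, with a plain nn.Linear standing in for the CLIPTextModel (which, unlike Linear, exposes .device itself):

import torch

class Wrapper:
    def __init__(self):
        self.transformer = torch.nn.Linear(4, 4)

    @property
    def device(self):
        # CLIPTextModel has a .device property; for a bare Linear we
        # read the device off one of its parameters instead.
        return next(self.transformer.parameters()).device

    @device.setter
    def device(self, device):
        self.transformer.to(device=device)

w = Wrapper()
w.device = 'cpu'   # forwards to transformer.to(device='cpu')
print(w.device)    # device(type='cpu')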