model_manager: model to/from CPU methods are implemented on the Pipeline

This commit is contained in:
Kevin Turner 2023-03-09 17:59:55 -08:00
parent 42355b70c2
commit ad7b1fa6fb

View File

@ -104,7 +104,7 @@ class ModelManager(object):
if model_name in self.models: if model_name in self.models:
requested_model = self.models[model_name]["model"] requested_model = self.models[model_name]["model"]
print(f">> Retrieving model {model_name} from system RAM cache") print(f">> Retrieving model {model_name} from system RAM cache")
self.models[model_name]["model"] = self._model_from_cpu(requested_model) requested_model.ready()
width = self.models[model_name]["width"] width = self.models[model_name]["width"]
height = self.models[model_name]["height"] height = self.models[model_name]["height"]
hash = self.models[model_name]["hash"] hash = self.models[model_name]["hash"]
@ -499,6 +499,7 @@ class ModelManager(object):
print(f">> Offloading {model_name} to CPU") print(f">> Offloading {model_name} to CPU")
model = self.models[model_name]["model"] model = self.models[model_name]["model"]
model.offload_all()
self.models[model_name]["model"] = self._model_to_cpu(model) self.models[model_name]["model"] = self._model_to_cpu(model)
gc.collect() gc.collect()
@ -1044,43 +1045,6 @@ class ModelManager(object):
self.stack.remove(model_name) self.stack.remove(model_name)
self.models.pop(model_name, None) self.models.pop(model_name, None)
def _model_to_cpu(self, model):
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model
model.cond_stage_model.device = CPU_DEVICE
model.to(CPU_DEVICE)
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
getattr(model, submodel).to(CPU_DEVICE)
except AttributeError:
pass
return model
def _model_from_cpu(self, model):
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.ready()
return model
model.to(self.device)
model.cond_stage_model.device = self.device
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
getattr(model, submodel).to(self.device)
except AttributeError:
pass
return model
def _pop_oldest_model(self): def _pop_oldest_model(self):
""" """
Remove the first element of the FIFO, which ought Remove the first element of the FIFO, which ought