diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index b5dc775c73..c67d02bb6c 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -494,7 +494,7 @@ The vast majority of these arguments default to reasonable values. sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) - model.cuda() + model.to(self.device) model.eval() if self.full_precision: print('Using slower but more accurate full-precision math (--full_precision)')