diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0103e3af55..8c6b23944c 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation): **self.clip.text_encoder.dict(), ) with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder,\ - ExitStack() as stack: + text_encoder_info as text_encoder: - loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] + loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: ti_list.append( - stack.enter_context( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ) - ) + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) except Exception: #print(e) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1d7f218bfc..e6392271b1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -362,8 +362,7 @@ class TextToLatentsInvocation(BaseInvocation): self.dispatch_progress(context, source_node_id, state) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet,\ - ExitStack() as stack: + with unet_info as unet: scheduler = get_scheduler( context=context, @@ -374,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -438,8 +437,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): **self.unet.unet.dict(), ) - with unet_info as unet,\ - ExitStack() as stack: + with unet_info as unet: scheduler = get_scheduler( context=context, @@ -468,7 +466,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] with ModelPatcher.apply_lora_unet(pipeline.unet, loras): result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 7490414bce..760fa08a12 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + # TODO: ui rewrite + base_model = BaseModelType.StableDiffusion1 + if not context.services.model_manager.model_exists( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, ): raise Exception(f"Unkown lora name: {self.lora_name}!") @@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation): output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, submodel=None, weight=self.weight, ) @@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation): output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, submodel=None, weight=self.weight, ) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index c351a76590..7f0f3985c0 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -70,7 +70,7 @@ class LoRALayerBase: op = torch.nn.functional.linear extra_args = {} - weight = self.get_weight(module) + weight = self.get_weight() bias = self.bias if self.bias is not None else 0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 @@ -81,7 +81,7 @@ class LoRALayerBase: **extra_args, ) * multiplier * scale - def get_weight(self, module: torch.nn.Module): + def get_weight(self): raise NotImplementedError() def calc_size(self) -> int: @@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase): self.rank = self.down.shape[0] - def get_weight(self, module: torch.nn.Module): + def get_weight(self): if self.mid is not None: up = self.up.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(up.shape[0], up.shape[1]) @@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase): self.rank = self.w1_b.shape[0] - def get_weight(self, module: torch.nn.Module): + def get_weight(self): if self.t1 is None: weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase): else: self.rank = None # unscaled - def get_weight(self, module: torch.nn.Module): + def get_weight(self): w1 = self.w1 if w1 is None: w1 = self.w1_a @ self.w1_b @@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase): if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) w2 = w2.contiguous() - weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape? + weight = torch.kron(w1, w2) return weight @@ -471,7 +471,7 @@ class ModelPatcher: submodule_name += "_" + key_parts.pop(0) module = module.get_submodule(submodule_name) - module_key = module_key.rstrip(".") + module_key = (module_key + "." + submodule_name).lstrip(".") return (module_key, module) @@ -525,23 +525,36 @@ class ModelPatcher: loras: List[Tuple[LoraModel, float]], prefix: str, ): - hooks = dict() + original_weights = dict() try: - for lora, lora_weight in loras: - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue + with torch.no_grad(): + for lora, lora_weight in loras: + #assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - if module_key not in hooks: - hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key)) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + if module_key not in original_weights: + original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) + + # enable autocast to calc fp16 loras on cpu + with torch.autocast(device_type="cpu"): + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale + + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + layer_weight = layer_weight.reshape(module.weight.shape) + + module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype) yield # wait for context manager exit finally: - for module_key, hook in hooks.items(): - hook.remove() - hooks.clear() + with torch.no_grad(): + for module_key, weight in original_weights.items(): + model.get_submodule(module_key).weight.copy_(weight) @classmethod @@ -591,7 +604,7 @@ class ModelPatcher: f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}." ) - model_embeddings.weight.data[token_id] = embedding + model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype) ti_tokens.append(token_id) if len(ti_tokens) > 1: