final tidying before marking PR as ready for review

- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type-check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from the model manager
- Remove last vestiges of the old model manager
- Update tests and documentation
- Resolve conflict with seamless.py
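For readers unfamiliar with the pattern, here is a minimal sketch of a decorator-based loader registry, the general shape of thing AnyModelLoader was replaced with. All names and signatures below are illustrative assumptions, not InvokeAI's actual API:

    # Illustrative sketch only -- names are assumptions, not InvokeAI's API.
    from typing import Callable, Dict, Type


    class ModelLoaderRegistry:
        """Map a model-type key to a loader class via a decorator."""

        _registry: Dict[str, Type] = {}

        @classmethod
        def register(cls, key: str) -> Callable[[Type], Type]:
            # Usage: @ModelLoaderRegistry.register("vae")
            def decorator(loader_cls: Type) -> Type:
                cls._registry[key] = loader_cls
                return loader_cls

            return decorator

        @classmethod
        def get(cls, key: str) -> Type:
            return cls._registry[key]

Compared with a single do-everything loader class, a registry lets each loader live next to the model format it handles and keeps dispatch in one place.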
@@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure):
         ("keepcost", ctypes.c_size_t),
     ]
 
-    def __str__(self):
+    def __str__(self) -> str:
         s = ""
         s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
         s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
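For reference, the hunk above shows only the tail of the `_fields_` list. glibc's `struct mallinfo2` is ten `size_t` members, so the complete declaration presumably looks like this (reconstructed from the glibc definition, not copied from the file):

    import ctypes

    # glibc >= 2.33: struct mallinfo2 -- all members are size_t.
    class Struct_mallinfo2(ctypes.Structure):
        _fields_ = [
            ("arena", ctypes.c_size_t),     # non-mmapped space allocated
            ("ordblks", ctypes.c_size_t),   # number of free chunks
            ("smblks", ctypes.c_size_t),    # number of free fastbin blocks
            ("hblks", ctypes.c_size_t),     # number of mmapped regions
            ("hblkhd", ctypes.c_size_t),    # space allocated in mmapped regions
            ("usmblks", ctypes.c_size_t),   # unused, always 0
            ("fsmblks", ctypes.c_size_t),   # space in freed fastbin blocks
            ("uordblks", ctypes.c_size_t),  # total allocated space
            ("fordblks", ctypes.c_size_t),  # total free space
            ("keepcost", ctypes.c_size_t),  # releasable space at top of heap
        ]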
@@ -62,7 +62,7 @@ class LibcUtil:
     TODO: Improve cross-OS compatibility of this class.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
 
     def mallinfo2(self) -> Struct_mallinfo2:
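Regarding the TODO above: one way to soften the hard-coded "libc.so.6" is ctypes.util.find_library, which resolves the platform's C library name (e.g. "libc.so.6" on glibc Linux). Note that mallinfo2() itself only exists on glibc 2.33+, so this only helps the lookup, not the call:

    import ctypes
    import ctypes.util

    # Resolve the C library portably instead of hard-coding "libc.so.6".
    libc_path = ctypes.util.find_library("c")
    libc = ctypes.CDLL(libc_path) if libc_path else None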
@@ -72,4 +72,5 @@ class LibcUtil:
         """
         mallinfo2 = self._libc.mallinfo2
         mallinfo2.restype = Struct_mallinfo2
-        return mallinfo2()
+        result: Struct_mallinfo2 = mallinfo2()
+        return result
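A usage sketch for the class in this hunk. Linux/glibc only, and the module path in the import is an assumption:

    # Usage sketch, assuming the module path; works only on Linux with
    # glibc >= 2.33, since LibcUtil loads libc.so.6 and calls mallinfo2().
    from invokeai.backend.model_manager.libc_util import LibcUtil  # path is an assumption

    info = LibcUtil().mallinfo2()
    print(info)  # __str__ reports arena size, free chunks, etc.
    print(f"non-mmapped heap: {info.arena / 2**30:.3f} GB")

Setting restype to the Structure subclass is what lets ctypes marshal the returned struct by value; without it the call would be misread as returning an int.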
@@ -1,12 +1,15 @@
 """Utilities for parsing model files, used mostly by probe.py"""
 
 import json
-import torch
-from typing import Union
 from pathlib import Path
+from typing import Dict, Optional, Union
 
 import safetensors
+import torch
 from picklescan.scanner import scan_file_path
 
-def _fast_safetensors_reader(path: str):
+
+def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
     checkpoint = {}
     device = torch.device("meta")
     with open(path, "rb") as f:
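Context for the hunk above: the safetensors on-disk format begins with an 8-byte little-endian header length followed by that many bytes of JSON tensor metadata (dtype, shape, byte offsets), which is why a metadata-only "fast reader" can avoid loading weights onto a real device. A self-contained sketch of reading just that header:

    import json
    import struct


    def read_safetensors_header(path: str) -> dict:
        # safetensors files start with an unsigned 64-bit little-endian
        # header length, then the JSON metadata table -- no tensor data
        # needs to be read at all.
        with open(path, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))
            return json.loads(f.read(header_len))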
@@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str):
 
     return checkpoint
 
-def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
+
+def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
     if str(path).endswith(".safetensors"):
         try:
-            checkpoint = _fast_safetensors_reader(path)
+            path_str = path.as_posix() if isinstance(path, Path) else path
+            checkpoint = _fast_safetensors_reader(path_str)
         except Exception:
             # TODO: create issue for support "meta"?
             checkpoint = safetensors.torch.load_file(path, device="cpu")
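The scan path in read_checkpoint_meta relies on picklescan, imported at the top of the file. A hedged sketch of the guard it enables; the ScanResult attribute name is from picklescan as I understand it, so verify against the installed version:

    from picklescan.scanner import scan_file_path

    # Vet a pickle-based checkpoint before handing it to torch.load.
    result = scan_file_path("/path/to/model.ckpt")
    if result.infected_files > 0:
        raise RuntimeError("Potentially malicious pickle payload detected")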
@@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
         checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint
 
-def lora_token_vector_length(checkpoint: dict) -> int:
+
+def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
     """
     Given a checkpoint in memory, return the lora token vector length
 
     :param checkpoint: The checkpoint
     """
 
-    def _get_shape_1(key: str, tensor, checkpoint) -> int:
+    def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
         lora_token_vector_length = None
 
         if "." not in key:
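Downstream, the probe uses this length to tell base models apart, which is why the widened Optional[int] return type matters (unrecognized key layouts yield None). A hedged usage sketch; the module path and the 768/1024 mapping are assumptions based on common Stable Diffusion text-encoder widths, not the probe's exact decision table:

    # Hedged usage sketch; module path is an assumption.
    from invokeai.backend.model_manager.util.model_util import (
        lora_token_vector_length,
        read_checkpoint_meta,
    )

    checkpoint = read_checkpoint_meta("/path/to/lora.safetensors", scan=True)
    length = lora_token_vector_length(checkpoint)
    if length == 768:
        base = "sd-1"  # SD 1.x CLIP hidden size
    elif length == 1024:
        base = "sd-2"  # SD 2.x OpenCLIP hidden size
    else:
        base = "unknown"
    print(base)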