InvokeAI/scripts/orig_scripts/train_searcher.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

170 lines
5.4 KiB
Python
Raw Normal View History

2023-08-17 22:45:25 +00:00
import os
import sys
2022-08-10 14:30:49 +00:00
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm
from ldm.util import parallel_data_prefetch
def search_bruteforce(searcher):
    """Configure ``searcher`` for brute-force scoring and build it.

    Args:
        searcher: Builder object exposing ``score_brute_force()``; the object
            that call returns must expose ``build()``.

    Returns:
        Whatever the final ``build()`` call returns.
    """
    brute_force_builder = searcher.score_brute_force()
    built = brute_force_builder.build()
    return built
def search_partioned_ah(
    searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
    """Build a searcher that chains ``tree`` -> ``score_ah`` -> ``reorder`` -> ``build``.

    Args:
        searcher: Builder object the chain starts from.
        dims_per_block: Positional argument forwarded to ``score_ah``.
        aiq_threshold: Forwarded to ``score_ah`` as
            ``anisotropic_quantization_threshold``.
        reorder_k: Positional argument forwarded to ``reorder``.
        partioning_trainsize: Forwarded to ``tree`` as ``training_sample_size``.
        num_leaves: Forwarded to ``tree`` as ``num_leaves``.
        num_leaves_to_search: Forwarded to ``tree`` as ``num_leaves_to_search``.

    Returns:
        Whatever the final ``build()`` call returns.
    """
    partitioned = searcher.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_sample_size=partioning_trainsize,
    )
    hashed = partitioned.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
    reordered = hashed.reorder(reorder_k)
    return reordered.build()
2022-08-10 14:30:49 +00:00
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    """Build a searcher that chains ``score_ah`` -> ``reorder`` -> ``build``.

    Args:
        searcher: Builder object the chain starts from.
        dims_per_block: Positional argument forwarded to ``score_ah``.
        aiq_threshold: Forwarded to ``score_ah`` as
            ``anisotropic_quantization_threshold``.
        reorder_k: Positional argument forwarded to ``reorder``.

    Returns:
        Whatever the final ``build()`` call returns.
    """
    hashed = searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
    reordered = hashed.reorder(reorder_k)
    return reordered.build()
def load_datapool(dpath):
    """Load the retrieval database stored as ``*.npz`` files directly inside ``dpath``.

    With exactly one archive, every array in it is read into a dict keyed by
    the archive's entry names. With several archives, they are handed to
    ``parallel_data_prefetch`` and the per-worker results are merged per key
    with ``np.concatenate(..., axis=1)[0]``.

    Progress messages are printed to stdout. The final message reads
    ``data_pool["embedding"]``, so the archives must contain an
    ``"embedding"`` entry.

    Args:
        dpath: Directory that is searched (non-recursively) for ``*.npz`` files.

    Returns:
        dict mapping each archive key to its loaded array.

    Raises:
        ValueError: If ``dpath`` contains no ``*.npz`` file.
    """

    def load_single_file(saved_embeddings):
        # Read every array while the archive is open, then close it so the
        # underlying file handle is not leaked.
        with np.load(saved_embeddings) as compressed:
            return {key: compressed[key] for key in compressed.files}

    def load_multi_files(data_archive):
        # Collect, per key, the arrays of all archives handed to this worker.
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                database[key].append(d[key])
        return database

    print(f'Load saved patch embedding from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, "*.npz"))
    if not file_content:
        raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    else:
        data = [np.load(f) for f in file_content]
        try:
            prefetched_data = parallel_data_prefetch(
                load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )
        finally:
            # The arrays have been read out by now; release the archive handles.
            # NOTE(review): assumes parallel_data_prefetch returns fully
            # materialised results - confirm against ldm.util.
            for archive in data:
                archive.close()
        data_pool = {
            key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
        }

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool
def train_searcher(
    opt,
    metric="dot_product",
    partioning_trainsize=None,
    reorder_k=None,
    # todo tune
    aiq_thld=0.2,
    dims_per_block=2,
    num_leaves=None,
    num_leaves_to_search=None,
):
    """Build a ScaNN searcher over the stored embeddings and serialize it to disk.

    The database is loaded from ``opt.database``, its ``"embedding"`` rows are
    divided by their L2 norm, and one of three search strategies is chosen from
    the number of rows. The built searcher is written to ``opt.target_path``,
    which is created if missing. Progress is printed to stdout.

    Args:
        opt: Object with ``database``, ``knn`` and ``target_path`` attributes.
        metric: Distance measure passed to the ScaNN builder.
        partioning_trainsize: Training sample size for partitioning; a falsy
            value means one tenth of the pool size.
        reorder_k: Reordering parameter; a falsy value means ``2 * opt.knn``.
        aiq_thld: Anisotropic quantization threshold passed to ``score_ah``.
        dims_per_block: Dimensions per block passed to ``score_ah``.
        num_leaves: Number of partitions; a falsy value means
            ``int(sqrt(pool_size))``.
        num_leaves_to_search: Partitions searched per query; a falsy value
            means ``max(num_leaves // 20, 1)``.
    """
    data_pool = load_datapool(opt.database)
    k = opt.knn
    reorder_k = reorder_k or 2 * k

    # normalize: scale each embedding row by its L2 norm
    embeddings = data_pool["embedding"]
    row_norms = np.linalg.norm(embeddings, axis=1)
    searcher = scann.scann_ops_pybind.builder(embeddings / row_norms[:, np.newaxis], k, metric)
    pool_size = embeddings.shape[0]

    banner = ["#"] * 100
    print(*banner)
    print("Initializing scaNN searcher with the following values:")
    print(f"k: {k}")
    print(f"metric: {metric}")
    print(f"reorder_k: {reorder_k}")
    print(f"anisotropic_quantization_threshold: {aiq_thld}")
    print(f"dims_per_block: {dims_per_block}")
    print(*banner)
    print("Start training searcher....")
    print(f"N samples in pool is {pool_size}")

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size >= 1e5:
        print("Using using partioning, asymmetric hashing search and reordering.")
        partioning_trainsize = partioning_trainsize or pool_size // 10
        num_leaves = num_leaves or int(np.sqrt(pool_size))
        num_leaves_to_search = num_leaves_to_search or max(num_leaves // 20, 1)

        print("Partitioning params:")
        print(f"num_leaves: {num_leaves}")
        print(f"num_leaves_to_search: {num_leaves_to_search}")
        searcher = search_partioned_ah(
            searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
        )
    elif pool_size >= 2e4:
        print("Using asymmetric hashing search and reordering.")
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print("Using brute force search.")
        searcher = search_bruteforce(searcher)

    print("Finish training searcher")
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')
2023-07-27 14:54:01 +00:00
2022-08-10 14:30:49 +00:00
if __name__ == "__main__":
    # Script entry point: parse the command line and train/serialize a searcher.
    sys.path.append(os.getcwd())

    # (flags, keyword settings) for every supported command-line option.
    option_specs = (
        (
            ("--database", "-d"),
            {
                "default": "data/rdm/retrieval_databases/openimages",
                "type": str,
                "help": "path to folder containing the clip feature of the database",
            },
        ),
        (
            ("--target_path", "-t"),
            {
                "default": "data/rdm/searchers/openimages",
                "type": str,
                "help": "path to the target folder where the searcher shall be stored.",
            },
        ),
        (
            ("--knn", "-k"),
            {
                "default": 20,
                "type": int,
                "help": "number of nearest neighbors, for which the searcher shall be optimized",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    # Unrecognised arguments are tolerated and discarded.
    opt, _ = parser.parse_known_args()

    train_searcher(opt)