浅谈sparse vec检索工程化实现 - JadePeng
source link: https://www.cnblogs.com/xiaoqi/p/18150639/golang-sparse-retrival
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
前面我们通过两篇文章: BGE M3-Embedding 模型介绍 和 Sparse稀疏检索介绍与实践 介绍了sparse 稀疏检索,今天我们来看看如何建立一个工程化的系统来实现sparse vec的检索。
之前提过milvus最新的V2.4支持sparse检索,我们先看看milvus的实现。
milvus的sparse检索实现
milvus 的检索底层引擎是 knowhere,sparse 检索的主要代码在 src/index/sparse 目录里。
首先是数据结构 SparseRow,它用于表示一个稀疏向量,目前只支持浮点数(float)类型的数据(见类中的 static_assert):
// SparseRow holds one sparse vector as a densely packed array of
// (table_t index, T value) pairs that must be sorted by index. The buffer is
// either owned (allocated by the constructor, released by the destructor) or
// borrowed (own_data_ == false), so a row can point at externally managed
// memory such as an mmap'ed file without copying.
// NOTE(review): table_t, fp32 and SparseIdVal<T> are declared elsewhere in knowhere.
template <typename T>
class SparseRow {
    static_assert(std::is_same_v<T, fp32>, "SparseRow supports float only");

 public:
    // construct an SparseRow with memory allocated to hold `count` elements.
    SparseRow(size_t count = 0)
        : data_(count ? new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) {
    }

    // Wrap an existing buffer holding `count` elements. When own_data is false
    // the caller keeps ownership and must keep `data` alive while this row is used.
    SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) {
    }

    // copy constructor and copy assignment operator perform deep copy; the
    // copy always owns its buffer, even when `other` was borrowing.
    SparseRow(const SparseRow<T>& other) : SparseRow(other.count_) {
        // An empty row has data_ == nullptr, and memcpy with a null pointer is
        // undefined behavior even for a zero-byte copy, so skip the call.
        if (count_ != 0) {
            std::memcpy(data_, other.data_, data_byte_size());
        }
    }

    // Move construction leaves `other` empty (it receives a default row).
    SparseRow(SparseRow<T>&& other) noexcept : SparseRow() {
        swap(*this, other);
    }

    SparseRow&
    operator=(const SparseRow<T>& other) {
        if (this != &other) {
            SparseRow<T> tmp(other);
            swap(*this, tmp);
        }
        return *this;
    }

    // Move assignment swaps: `other` takes over the previous buffer and
    // releases it (if owned) when it is destroyed.
    SparseRow&
    operator=(SparseRow<T>&& other) noexcept {
        swap(*this, other);
        return *this;
    }

    ~SparseRow() {
        if (own_data_ && data_ != nullptr) {
            delete[] data_;
            data_ = nullptr;
        }
    }

    // Number of stored (index, value) elements.
    size_t
    size() const {
        return count_;
    }

    // Bytes used by the element buffer plus the object itself.
    size_t
    memory_usage() const {
        return data_byte_size() + sizeof(*this);
    }

    // return the number of bytes used by the underlying data array.
    size_t
    data_byte_size() const {
        return count_ * element_size();
    }

    void*
    data() {
        return data_;
    }

    const void*
    data() const {
        return data_;
    }

    // dim of a sparse vector is the max index + 1, or 0 for an empty vector.
    // Relies on the elements being sorted by index (the last one is the max).
    int64_t
    dim() const {
        if (count_ == 0) {
            return 0;
        }
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + count_ - 1;
        return elem->index + 1;
    }

    // Return the i-th (index, value) pair; i must be < size().
    SparseIdVal<T>
    operator[](size_t i) const {
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + i;
        return {elem->index, elem->value};
    }

    // Overwrite the i-th element; the caller must keep indices sorted.
    void
    set_at(size_t i, table_t index, T value) {
        auto* elem = reinterpret_cast<ElementProxy*>(data_) + i;
        elem->index = index;
        elem->value = value;
    }

    // Inner product with `other`: a merge-style walk over both index-sorted
    // element arrays, multiplying values whose indices match.
    float
    dot(const SparseRow<T>& other) const {
        float product_sum = 0.0f;
        size_t i = 0;
        size_t j = 0;
        // TODO: improve with _mm_cmpistrm or the AVX512 alternative.
        while (i < count_ && j < other.count_) {
            auto* left = reinterpret_cast<const ElementProxy*>(data_) + i;
            auto* right = reinterpret_cast<const ElementProxy*>(other.data_) + j;
            if (left->index < right->index) {
                ++i;
            } else if (left->index > right->index) {
                ++j;
            } else {
                product_sum += left->value * right->value;
                ++i;
                ++j;
            }
        }
        return product_sum;
    }

    friend void
    swap(SparseRow<T>& left, SparseRow<T>& right) {
        using std::swap;
        swap(left.count_, right.count_);
        swap(left.data_, right.data_);
        swap(left.own_data_, right.own_data_);
    }

    // Size in bytes of one packed (index, value) element; there is no padding.
    static inline size_t
    element_size() {
        return sizeof(table_t) + sizeof(T);
    }

 private:
    // ElementProxy is used to access elements in the data_ array and should
    // never be actually constructed.
    struct __attribute__((packed)) ElementProxy {
        table_t index;
        T value;
        ElementProxy() = delete;
        ElementProxy(const ElementProxy&) = delete;
    };

    // data_ must be sorted by column id. use raw pointer for easy mmap and zero
    // copy.
    uint8_t* data_;
    size_t count_;
    bool own_data_;
};
然后,索引的具体实现在 InvertedIndex 类里,对应 sparse_inverted_index.h 文件。首先看它定义的一些 private 字段:
// Original rows, in insertion order. Row ids stored in inverted_lut_ are
// positions in this vector (Add passes n_rows_internal() + i, Load passes i).
std::vector<SparseRow<T>> raw_data_;
// Locked exclusively by Add/Load and shared by Save.
mutable std::shared_mutex mu_;
// Inverted index: dimension (term id) -> list of (row id, value) entries.
std::unordered_map<table_t, std::vector<SparseIdVal<T>>> inverted_lut_;
// Whether max_in_dim_ is maintained and WAND search is used; Save encodes it
// as the sign of the row count.
bool use_wand_ = false;
// If we want to drop small values during build, we must first train the
// index with all the data to compute value_threshold_.
bool drop_during_build_ = false;
// when drop_during_build_ is true, any value whose absolute value is smaller
// than value_threshold_ will not be added to inverted_lut_. value_threshold_
// is set to the drop_ratio_build-th percentile of all absolute values in the index.
T value_threshold_ = 0.0f;
// Per-dimension largest value added to inverted_lut_ (only maintained when
// use_wand_ is set); search_wand multiplies it by the query weight to get a
// cursor's max score.
std::unordered_map<table_t, T> max_in_dim_;
// Largest `dim` passed to Add (or the value restored by Load).
// NOTE(review): presumably what n_cols_internal() returns — confirm.
size_t max_dim_ = 0;
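- raw_data_ 是原始的数据,每个 doc 对应一行 SparseRow
- inverted_lut_ 可以理解为一个倒排表:term(维度)id → 包含该 term 的 (doc id, 权重) 列表
- use_wand_ 用于控制查询时是否使用 WAND 算法。WAND 是经典的查询优化算法,可以通过类似跳表的方式跳过一些数据,减少计算量,提升查询效率
- max_in_dim_ 是为 WAND 服务的,记录每个维度上出现过的最大权重,用来估算分数上界
- drop_during_build_ 和 value_threshold_ 用于在构建时丢弃绝对值过小的权重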
索引构建流程
构建,主要是对外提供一个Add数据的方法:
// Append `rows` rows from `data` to the index and register each in the
// inverted index with row id n_rows_internal() + offset. `dim` widens
// max_dim_ when it is larger. Appending to an index that already holds rows
// is rejected when drop_during_build_ is set, because value_threshold_ was
// computed from the data present at build time.
Status
Add(const SparseRow<T>* data, size_t rows, int64_t dim) {
    std::unique_lock<std::shared_mutex> lock(mu_);
    const auto base_id = n_rows_internal();
    if (drop_during_build_ && base_id > 0) {
        LOG_KNOWHERE_ERROR_ << "Not allowed to add data to a built index with drop_ratio_build > 0.";
        return Status::invalid_args;
    }
    max_dim_ = std::max(max_dim_, static_cast<size_t>(dim));

    // Keep deep copies of the rows, then index every one of them.
    raw_data_.insert(raw_data_.end(), data, data + rows);
    for (size_t offset = 0; offset < rows; ++offset) {
        add_row_to_index(data[offset], base_id + offset);
    }
    return Status::success;
}
这里会更新数据的max_dim,数据追加到raw_data_,然后add_row_to_index,将新的doc放入inverted_lut_, 并更新max_in_dim_,用于记录最大值,方便wand查询时跳过计算。
inline void
add_row_to_index(const SparseRow<T>& row, table_t id) {
for (size_t j = 0; j < row.size(); ++j) {
auto [idx, val] = row[j];
// Skip values close enough to zero(which contributes little to
// the total IP score).
if (drop_during_build_ && fabs(val) < value_threshold_) {
continue;
}
if (inverted_lut_.find(idx) == inverted_lut_.end()) {
inverted_lut_[idx];
if (use_wand_) {
max_in_dim_[idx] = 0;
}
}
inverted_lut_[idx].emplace_back(id, val);
if (use_wand_) {
max_in_dim_[idx] = std::max(max_in_dim_[idx], val);
}
}
}
索引保存与load
保存时,是自定义的二进制文件:
// Serialize the index into `writer` using the layout described below.
// Holds a shared lock on mu_ while writing.
Status
Save(MemoryIOWriter& writer) {
/**
* zero copy is not yet implemented, now serializing in a zero copy
* compatible way while still copying during deserialization.
*
* Layout (matching the writes below and the reads in Load):
*
* 1. rows: n_rows_internal() multiplied by +1 when use_wand_ is set and by
*    -1 otherwise. Load reads it back as int64_t and takes the sign as the
*    WAND flag. NOTE(review): with 0 rows the sign is lost, so an empty
*    index always reloads with use_wand_ == false. The on-disk width is that
*    of n_rows_internal()'s return type (presumably size_t, 8 bytes), not
*    int32_t.
* 2. cols: n_cols_internal(); Load reads it into max_dim_ (size_t).
* 3. value_threshold_ (T)
* 4. for each row:
*    1. len: row.size() (size_t)
*    2. for each non-zero value:
*       1. table_t idx
*       2. T val
* With zero copy deserialization, each SparseRow object should
* reference(not owning) the memory address of the first element.
*
* inverted_lut_ and max_in_dim_ not serialized, they will be
* constructed dynamically during deserialization.
*
* Data are densely packed in serialized bytes and no padding is added.
*/
std::shared_lock<std::shared_mutex> lock(mu_);
writeBinaryPOD(writer, n_rows_internal() * (use_wand_ ? 1 : -1));
writeBinaryPOD(writer, n_cols_internal());
writeBinaryPOD(writer, value_threshold_);
for (size_t i = 0; i < n_rows_internal(); ++i) {
auto& row = raw_data_[i];
writeBinaryPOD(writer, row.size());
// An empty row stores only its length.
if (row.size() == 0) {
continue;
}
writer.write(row.data(), row.size() * SparseRow<T>::element_size());
}
return Status::success;
}
索引文件格式(以 Save 实际写入的内容为准):
- rows:总记录数,通过 ± 符号来区分是否 use wand。Load 时按 int64_t 读取;注意 0 行时无法通过符号区分
- cols:列数,即 max_dim_(size_t)
- value_threshold_:构建时丢弃小权重所用的阈值(类型 T)
- for each row:
  1. size_t len:该行非零元素的个数
  2. for each non-zero value:
     1. table_t idx:term 的 id 编号
     2. T val:term 的权重
注意,这里的 inverted_lut_ 倒排表并没有存储,而是在加载的时候重建。所以 load 的过程就是一个逆过程:
// Deserialize an index written by Save, then rebuild inverted_lut_ (and
// max_in_dim_ when WAND is enabled) from the raw rows.
// Reads: an int64_t signed row count (the sign becomes use_wand_), max_dim_,
// value_threshold_, then for each row a size_t element count followed by
// the packed (index, value) pairs.
// NOTE: a saved index with 0 rows cannot carry the WAND flag (0 has no
// sign), so it always loads with use_wand_ == false.
Status
Load(MemoryIOReader& reader) {
    std::unique_lock<std::shared_mutex> lock(mu_);
    int64_t rows;
    readBinaryPOD(reader, rows);
    use_wand_ = rows > 0;
    rows = std::abs(rows);
    readBinaryPOD(reader, max_dim_);
    readBinaryPOD(reader, value_threshold_);
    // Rows are appended after any already present. Each one is read into the
    // element just appended, so the buffer always matches `count`, and its
    // row id is its position in raw_data_ (0..rows-1 for a fresh index).
    raw_data_.reserve(raw_data_.size() + static_cast<size_t>(rows));
    for (int64_t i = 0; i < rows; ++i) {
        size_t count;
        readBinaryPOD(reader, count);
        auto& row = raw_data_.emplace_back(count);
        if (count == 0) {
            continue;
        }
        reader.read(row.data(), count * SparseRow<T>::element_size());
        add_row_to_index(row, static_cast<table_t>(raw_data_.size() - 1));
    }
    return Status::success;
}
我们来回顾一下:compute_lexical_matching_score 其实就是把 query 和 doc 共同 term 的 weight 相乘,再把乘积加起来。所以可以想象,暴力检索大概就是把 query 中所有 term 对应的 doc 取并集,逐个计算 lexical_matching_score,最后取 topk。
我们来看milvus的实现,先看暴力检索:
// Brute-force top-k: score every row through the inverted index, then offer
// each row that is not filtered out by `bitset` and has a non-zero score to
// `heap` (k is the heap's capacity). Query entries below q_threshold or with
// dimension >= n_cols() are ignored by compute_all_distances.
// TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
void
search_brute_force(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
    const std::vector<float> row_scores = compute_all_distances(q_vec, q_threshold);
    const bool filter_active = !bitset.empty();
    for (size_t row = 0; row < n_rows_internal(); ++row) {
        const float s = row_scores[row];
        if (s == 0) {
            continue;  // zero score: nothing to rank
        }
        if (filter_active && bitset.test(row)) {
            continue;  // excluded by the filter bitset
        }
        heap.push(row, s);
    }
}
// Score every row against q_vec via the inverted index. For each query entry
// (dim, q_val), q_val * doc_val is added to every row in that dimension's
// posting list. Returns one float per row (0 when nothing contributed).
// Query entries with q_val < q_threshold (a signed comparison, unlike the
// |val| check in search_wand) or dim >= n_cols_internal() are skipped.
std::vector<float>
compute_all_distances(const SparseRow<T>& q_vec, T q_threshold) const {
    std::vector<float> acc(n_rows_internal(), 0.0f);
    const size_t n_query = q_vec.size();
    for (size_t q = 0; q < n_query; ++q) {
        const auto [dim_id, q_val] = q_vec[q];
        if (q_val < q_threshold || dim_id >= n_cols_internal()) {
            continue;
        }
        const auto posting = inverted_lut_.find(dim_id);
        if (posting == inverted_lut_.cend()) {
            continue;
        }
        // TODO: improve with SIMD
        for (const auto& [row_id, doc_val] : posting->second) {
            acc[row_id] += q_val * static_cast<float>(doc_val);
        }
    }
    return acc;
}
- 核心在 compute_all_distances 里:先遍历 q_vec 得到每一个 term id,然后从 inverted_lut_ 里找到该 term 对应的 doc list 并计算 score,相同 doc id 的 score 累加
- 最后用 MaxMinHeap 堆来取 topk
暴力检索能保证结果的精确性,但是效率比较低。我们来看使用 WAND 优化的检索:
// WAND top-k search. Any value in q_vec whose absolute value is smaller than
// q_threshold (or whose dimension is >= n_cols) will be ignored.
// NOTE(review): compute_all_distances compares the signed value instead.
// Cursors walk each dimension's posting list; rows whose score upper bound
// cannot beat the heap's current threshold are skipped via seek().
void
search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<std::vector<SparseIdVal<T>>>>> cursors(q_dim);
auto valid_q_dim = 0;
// Build one cursor per usable query dimension over its posting list.
for (size_t i = 0; i < q_dim; ++i) {
// idx is the dimension (term id), val the query weight.
auto [idx, val] = q_vec[i];
if (std::abs(val) < q_threshold || idx >= n_cols_internal()) {
continue;
}
auto lut_it = inverted_lut_.find(idx);
if (lut_it == inverted_lut_.end()) {
continue;
}
auto& lut = lut_it->second;
// max_in_dim_ holds the largest value stored in this dimension; times the
// query weight it is the cursor's max score (its upper-bound contribution).
cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<SparseIdVal<T>>>>(
lut, n_rows_internal(), max_in_dim_.find(idx)->second * val, val, bitset);
}
if (valid_q_dim == 0) {
return;
}
cursors.resize(valid_q_dim);
// Keep cursors ordered by their current vec id, smallest first.
// NOTE(review): exhausted cursors presumably report a vec id past every real
// row so they sort last; confirm in Cursor.
auto sort_cursors = [&cursors] {
std::sort(cursors.begin(), cursors.end(),
[](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
};
sort_cursors();
// True while the heap is not full, or when x beats the current heap top.
auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
while (true) {
// Running sum of cursor max scores: an upper bound on the score of any row
// the cursors seen so far can still produce.
float upper_bound = 0;
// pivot: index of the first cursor at which upper_bound beats the threshold.
size_t pivot;
bool found_pivot = false;
for (pivot = 0; pivot < cursors.size(); ++pivot) {
// Stop at the first exhausted posting list.
if (cursors[pivot]->is_end()) {
break;
}
upper_bound += cursors[pivot]->max_score();
if (score_above_threshold(upper_bound)) {
found_pivot = true;
break;
}
}
// No row can beat the threshold any more: done.
if (!found_pivot) {
break;
}
// pivot_id: the row the pivot cursor currently points at.
table_t pivot_id = cursors[pivot]->cur_vec_id();
// If the first cursor is already at pivot_id, then (cursors being sorted)
// every cursor up to the pivot is too, so score this row fully.
if (pivot_id == cursors[0]->cur_vec_id()) {
float score = 0;
// Sum the contributions of all cursors positioned at pivot_id.
for (auto& cursor : cursors) {
if (cursor->cur_vec_id() != pivot_id) {
break;
}
score += cursor->cur_distance() * cursor->q_value();
// Advance this posting list past pivot_id.
cursor->next();
}
// Offer the row to the top-k heap.
heap.push(pivot_id, score);
// Re-sort so the smallest current vec id is first again.
sort_cursors();
} else {
// The first cursor is behind pivot_id (so pivot >= 1). Walk back from
// the pivot past cursors already at pivot_id; this stops at a cursor
// that is behind pivot_id because cursors[0] is.
size_t next_list = pivot;
for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
}
// Seek that cursor forward so its vec id becomes >= pivot_id. The rows it
// skips cannot make the top-k: only the cursors before the pivot can
// contain them, and their summed max scores do not beat the threshold.
cursors[next_list]->seek(pivot_id);
// Bubble the advanced cursor right to restore sorted order.
for (size_t i = next_list + 1; i < cursors.size(); ++i) {
// Stop once the order is restored.
if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
break;
}
// Otherwise move the advanced cursor one step to the right.
std::swap(cursors[i], cursors[i - 1]);
}
}
}
}
- 首先把 query 各 term 的倒排链取出来放入 cursors,然后按当前 vec_id 对 cursors 排序,vec_id 较小的倒排链排在前面
- 然后通过 score_above_threshold 遍历 cursors:依次累加各倒排链的 max_score 作为分数上界,找到第一个满足条件的 cursor 序号 pivot。判断条件是"堆未满,或者上界 > 堆顶的 score",这样可以跳过一些 score 不可能进入 topk 的 doc
- 接着取出 pivot cursor 当前指向的 pivot_id(也就是 doc id),判断第一个倒排链的 cur_vec_id 是否等于 pivot_id:
  - 如果等于,就遍历倒排链计算 pivot_id 的 score,放入堆中,然后重排倒排链
  - 如果不等于,就把落后的倒排链 seek 到 pivot_id,跳过该倒排链中 vec_id < pivot_id 的数据(剪枝),再通过交换让 cursors 重新保持有序
用golang实现轻量级sparse vec检索
用类似milvus的方法,我们简单实现一个golang版本的
package main
import (
"container/heap"
"encoding/binary"
"fmt"
"io"
"math/rand"
"os"
"sort"
"time"
)
// Cursor walks one term's posting list during WAND search.
// docIDs must be sorted in ascending order; weights[i] is the document-side
// weight of the term for docIDs[i].
type Cursor struct {
	docIDs     []int32
	weights    []float64
	maxScore   float64 // upper bound on this term's score contribution, supplied by the caller
	termWeight float64 // query-side weight of the term
	currentIdx int     // position in docIDs/weights; len(docIDs) means exhausted
}

// NewCursor returns a cursor positioned at the first entry of the posting list.
func NewCursor(docIDs []int32, weights []float64, maxScore float64, weight float64) *Cursor {
	return &Cursor{
		docIDs:     docIDs,
		weights:    weights,
		maxScore:   maxScore,
		termWeight: weight,
		currentIdx: 0,
	}
}

// Next advances to the following posting entry.
func (c *Cursor) Next() {
	c.currentIdx++
}

// Seek advances to the first remaining entry whose docID is >= docId, or to
// the end if there is none. It never moves backwards. Because docIDs is
// sorted, the target is found by binary search over the remaining entries
// instead of a linear scan.
func (c *Cursor) Seek(docId int32) {
	if c.IsEnd() {
		return
	}
	rest := c.docIDs[c.currentIdx:]
	c.currentIdx += sort.Search(len(rest), func(i int) bool { return rest[i] >= docId })
}

// IsEnd reports whether the cursor has moved past the last entry.
func (c *Cursor) IsEnd() bool {
	return c.currentIdx >= len(c.docIDs)
}

// CurrentDocID returns the docID under the cursor. It panics if IsEnd() is true.
func (c *Cursor) CurrentDocID() int32 {
	return c.docIDs[c.currentIdx]
}

// CurrentDocWeight returns the term's weight in the current document.
// It panics if IsEnd() is true.
func (c *Cursor) CurrentDocWeight() float64 {
	return c.weights[c.currentIdx]
}
type (
	// DocVectors maps a docID to its sparse vector (termID -> weight).
	DocVectors map[int32]map[int32]float64
	// InvertedIndex maps a termID to the docIDs containing it; searches expect
	// each list in ascending docID order.
	InvertedIndex map[int32][]int32
	// TermMaxScores maps a termID to the largest weight seen for it in any document.
	TermMaxScores map[int32]float64
)
// SparseIndex is an in-memory sparse-vector index: the raw document vectors,
// a term -> docIDs inverted index, and per-term maximum weights used by WAND.
type SparseIndex struct {
	docVectors    DocVectors    // docID -> sparse vector
	invertedIndex InvertedIndex // termID -> docIDs containing the term
	termMaxScores TermMaxScores // termID -> max weight over all added docs
	dim           int32         // largest termID seen so far; starts at 0
}

// NewSparseIndex returns an empty index whose maps are allocated and ready for Add.
func NewSparseIndex() *SparseIndex {
	idx := new(SparseIndex)
	idx.docVectors = DocVectors{}
	idx.invertedIndex = InvertedIndex{}
	idx.termMaxScores = TermMaxScores{}
	return idx
}
// Add stores vec (termID -> weight) as the vector of docID and indexes it.
//
// Every posting list is kept sorted by ascending docID, which WandSearch's
// cursors require, so documents may be added in any order (for example when
// Load replays a file). Adding in increasing docID order stays an O(1) append.
// Re-adding an existing docID replaces its stored vector without duplicating
// posting entries. Terms that only appeared in the old vector keep their
// posting entry, and the searches then see no weight for them.
// termMaxScores only ever grows, and dim tracks the largest termID seen.
func (index *SparseIndex) Add(docID int32, vec map[int32]float64) {
	index.docVectors[docID] = vec
	for termID, score := range vec {
		index.invertedIndex[termID] = insertDocIDSorted(index.invertedIndex[termID], docID)
		// Track max score for each term
		if maxScore, ok := index.termMaxScores[termID]; !ok || score > maxScore {
			index.termMaxScores[termID] = score
		}
		if termID > index.dim {
			index.dim = termID
		}
	}
}

// insertDocIDSorted inserts docID into the ascending slice list, unless it is
// already present, and returns the updated slice.
func insertDocIDSorted(list []int32, docID int32) []int32 {
	n := len(list)
	if n == 0 || list[n-1] < docID {
		return append(list, docID) // fast path: IDs arriving in increasing order
	}
	// list[n-1] >= docID here, so pos < n.
	pos := sort.Search(n, func(i int) bool { return list[i] >= docID })
	if list[pos] == docID {
		return list
	}
	list = append(list, 0)
	copy(list[pos+1:], list[pos:])
	list[pos] = docID
	return list
}
// Save writes the index to filename in a little-endian binary format:
//
//	int32 dim
//	then, for each document in ascending docID order:
//	    int32 docID, int32 vecSize, then vecSize pairs of (int32 termID, float64 weight)
//
// Only document vectors are stored; Load rebuilds the inverted index through
// Add. Writing documents in docID order means the rebuilt posting lists come
// out sorted, which WandSearch relies on. The first write error, or a close
// error, is returned.
func (index *SparseIndex) Save(filename string) (err error) {
	file, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer func() {
		// Close can report delayed write failures; keep any earlier error.
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()

	// put records the first write error; later calls become no-ops.
	put := func(v interface{}) {
		if err == nil {
			err = binary.Write(file, binary.LittleEndian, v)
		}
	}

	// Write the dimension
	put(index.dim)

	docIDs := make([]int32, 0, len(index.docVectors))
	for docID := range index.docVectors {
		docIDs = append(docIDs, docID)
	}
	sort.Slice(docIDs, func(i, j int) bool { return docIDs[i] < docIDs[j] })

	// Write each document vector
	for _, docID := range docIDs {
		vec := index.docVectors[docID]
		put(docID)
		put(int32(len(vec)))
		for termID, score := range vec {
			put(termID)
			put(score)
		}
	}
	return err
}
// Load reads an index file written by Save and adds every stored document
// through Add, which rebuilds the inverted index and per-term max scores.
// Documents already in the index are kept. dim is set to the stored value
// and then raised by Add as needed.
//
// Every read is checked. A file that ends inside a record (including a
// missing header) yields io.ErrUnexpectedEOF, and a negative vector size is
// reported as corruption. Documents decoded before an error stay in the index.
func (index *SparseIndex) Load(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()

	// readField reads a value that must be present: EOF here means the file
	// was truncated in the middle of a record.
	readField := func(v interface{}) error {
		err := binary.Read(file, binary.LittleEndian, v)
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	var dim int32
	if err := readField(&dim); err != nil {
		return fmt.Errorf("read index header: %w", err)
	}
	index.dim = dim

	for {
		var docID int32
		// EOF exactly at a record boundary is the normal end of the file.
		if err := binary.Read(file, binary.LittleEndian, &docID); err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}
		var vecSize int32
		if err := readField(&vecSize); err != nil {
			return err
		}
		if vecSize < 0 {
			return fmt.Errorf("corrupt index file: doc %d has negative vector size %d", docID, vecSize)
		}
		vec := make(map[int32]float64)
		for i := int32(0); i < vecSize; i++ {
			var termID int32
			var score float64
			if err := readField(&termID); err != nil {
				return err
			}
			if err := readField(&score); err != nil {
				return err
			}
			vec[termID] = score
		}
		index.Add(docID, vec) // Rebuild the index
	}
}
// bruteSearch returns the IDs of the (up to) K highest-scoring documents for
// queryVec, scoring every document that shares a term with the query.
// The returned IDs are sorted by docID, not by score. K <= 0 yields an empty
// result instead of indexing into an empty heap.
func (index *SparseIndex) bruteSearch(queryVec map[int32]float64, K int) []int32 {
	if K <= 0 {
		return []int32{}
	}
	scores := computeAllDistances(queryVec, index)
	// Take the top K with a min-heap whose root is the current K-th best score.
	docHeap := &DocScoreHeap{}
	for docID, score := range scores {
		if docHeap.Len() < K {
			heap.Push(docHeap, &DocScore{docID, score})
		} else if (*docHeap)[0].score < score {
			heap.Pop(docHeap)
			heap.Push(docHeap, &DocScore{docID, score})
		}
	}
	topDocs := make([]int32, 0, docHeap.Len())
	for docHeap.Len() > 0 {
		el := heap.Pop(docHeap).(*DocScore)
		topDocs = append(topDocs, el.docID)
	}
	sort.Slice(topDocs, func(i, j int) bool {
		return topDocs[i] < topDocs[j]
	})
	return topDocs
}
// computeAllDistances returns docID -> dot product of queryVec with the
// document's vector, for every document whose stored vector shares at least
// one query term. Documents sharing no term are absent from the result.
func computeAllDistances(queryVec map[int32]float64, index *SparseIndex) map[int32]float64 {
	acc := map[int32]float64{}
	for term, qWeight := range queryVec {
		postings, ok := index.invertedIndex[term]
		if !ok {
			continue
		}
		for _, docID := range postings {
			// A missing doc or term reads as "not present" and contributes nothing.
			if docWeight, has := index.docVectors[docID][term]; has {
				acc[docID] += qWeight * docWeight
			}
		}
	}
	return acc
}
// WandSearch returns the IDs of the (up to) K documents with the highest
// dot-product score against queryVec. It uses the WAND algorithm to skip
// documents whose score upper bound cannot beat the current K-th best.
// The returned IDs are sorted by docID (not by score), like bruteSearch.
// K <= 0 yields an empty result, and query terms absent from the index are
// ignored. Posting lists must be sorted by ascending docID.
func (index *SparseIndex) WandSearch(queryVec map[int32]float64, K int) []int32 {
	if K <= 0 {
		return []int32{}
	}
	docHeap := &DocScoreHeap{}

	// One cursor per query term that has a posting list. Appending (rather
	// than pre-sizing to len(queryVec)) leaves no nil cursors behind for
	// query terms missing from the index.
	cursors := make([]*Cursor, 0, len(queryVec))
	for term, termWeight := range queryVec {
		postingList, exists := index.invertedIndex[term]
		if !exists {
			continue
		}
		// Document-side weight of this term for every doc in the posting list.
		weights := make([]float64, len(postingList))
		for i, docID := range postingList {
			weights[i] = index.docVectors[docID][term]
		}
		cursors = append(cursors, NewCursor(postingList, weights, index.termMaxScores[term]*termWeight, termWeight))
	}

	// cursorLess orders cursors by current docID with exhausted cursors last,
	// so CurrentDocID is never called on an exhausted cursor.
	cursorLess := func(a, b *Cursor) bool {
		if a.IsEnd() || b.IsEnd() {
			return !a.IsEnd() && b.IsEnd()
		}
		return a.CurrentDocID() < b.CurrentDocID()
	}
	sortCursors := func() {
		sort.Slice(cursors, func(i, j int) bool { return cursorLess(cursors[i], cursors[j]) })
	}
	// True while the heap is not full, or when value beats the current K-th best score.
	scoreAboveThreshold := func(value float64) bool {
		return docHeap.Len() < K || (*docHeap)[0].score < value
	}
	// pushTopK keeps only the K best (docID, score) pairs in the min-heap.
	pushTopK := func(docID int32, score float64) {
		if docHeap.Len() < K {
			heap.Push(docHeap, &DocScore{docID, score})
		} else if (*docHeap)[0].score < score {
			heap.Pop(docHeap)
			heap.Push(docHeap, &DocScore{docID, score})
		}
	}

	sortCursors()
	for {
		// Find the pivot: the first cursor at which the summed max scores could
		// beat the threshold. Exhausted cursors are sorted last, so stop there.
		upperBound := 0.0
		pivot := -1
		for i, c := range cursors {
			if c.IsEnd() {
				break
			}
			upperBound += c.maxScore
			if scoreAboveThreshold(upperBound) {
				pivot = i
				break
			}
		}
		if pivot < 0 {
			break
		}
		pivotID := cursors[pivot].CurrentDocID()

		if cursors[0].CurrentDocID() == pivotID {
			// Every cursor up to the pivot sits on pivotID: score that doc fully.
			score := 0.0
			for _, c := range cursors {
				if c.IsEnd() || c.CurrentDocID() != pivotID {
					break
				}
				score += c.CurrentDocWeight() * c.termWeight
				c.Next()
			}
			pushTopK(pivotID, score)
			// Re-sort so the smallest current docID comes first again.
			sortCursors()
			continue
		}

		// cursors[0] is behind pivotID (so pivot >= 1). Walk back from the pivot
		// past the cursors already on pivotID. The cursor reached is behind
		// pivotID and can jump straight to it, because no doc before pivotID
		// can beat the threshold.
		next := pivot
		for cursors[next].CurrentDocID() == pivotID {
			next--
		}
		cursors[next].Seek(pivotID)
		// Bubble the advanced (possibly exhausted) cursor right to restore order.
		for i := next + 1; i < len(cursors) && cursorLess(cursors[i], cursors[i-1]); i++ {
			cursors[i], cursors[i-1] = cursors[i-1], cursors[i]
		}
	}

	topDocs := make([]int32, 0, docHeap.Len())
	for docHeap.Len() > 0 {
		el := heap.Pop(docHeap).(*DocScore)
		topDocs = append(topDocs, el.docID)
	}
	sort.Slice(topDocs, func(i, j int) bool {
		return topDocs[i] < topDocs[j]
	})
	return topDocs
}
// Helper structure to manage the priority queue for the top-K documents
type DocScore struct {
docID int32
score float64
}
type DocScoreHeap []*DocScore
func (h DocScoreHeap) Len() int { return len(h) }
func (h DocScoreHeap) Less(i, j int) bool { return h[i].score < h[j].score }
func (h DocScoreHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *DocScoreHeap) Push(x interface{}) {
*h = append(*h, x.(*DocScore))
}
func (h *DocScoreHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// main builds a demo index of 1000 documents with random weights on four
// terms, then prints the WAND top-10 for a random query over the same terms.
func main() {
	index := NewSparseIndex()
	rand.Seed(time.Now().UnixNano())

	demoTerms := []int32{101, 150, 190, 500}
	randomVec := func() map[int32]float64 {
		v := make(map[int32]float64, len(demoTerms))
		for _, t := range demoTerms {
			v[t] = rand.Float64()
		}
		return v
	}

	// Document IDs are 1..1000, added in increasing order.
	for docID := int32(1); docID <= 1000; docID++ {
		index.Add(docID, randomVec())
	}
	//index.Save("index.bin")
	//index.Load("index.bin")
	topDocs := index.WandSearch(randomVec(), 10)
	fmt.Println("Top Docs:", topDocs)
}
- 代码实现了索引的构建、保存和加载,检索方面实现了暴力检索和WAND检索
- 注意,添加 doc 时需要保证 docid 递增(倒排链要求按 docid 有序)。实际应用中,可以由引擎维护一个真实 id 到递增 docid 的映射
- 代码中已经有注释,这里不再赘述,注意代码未充分调试,可能有bug
- 代码实现倒排表全放到内存,效率高,但对内存要求高
sparse 检索整体类似传统的文本检索,因此传统的工程优化方法可以运用到sparse检索中,本文分析了milvus的实现,并实现了一个golang版本的sparse检索。
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK