0

浅谈sparse vec检索工程化实现 - JadePeng

 3 weeks ago
source link: https://www.cnblogs.com/xiaoqi/p/18150639/golang-sparse-retrival
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

前面我们通过两篇文章《BGE M3-Embedding 模型介绍》和《Sparse稀疏检索介绍与实践》介绍了sparse稀疏检索,今天我们来看看如何建立一个工程化的系统来实现sparse vec的检索。

之前提过milvus最新的V2.4支持sparse检索,我们先看看milvus的实现。

milvus的sparse检索实现

milvus 检索底层引擎是knowhere,主要代码在src/index/sparse 里。

首先是数据结构SparseRow,它用于表示一个稀疏向量,目前只支持浮点数(float)类型的数据:

// SparseRow is one sparse vector, stored as a densely packed array of
// (table_t index, T value) elements inside a raw byte buffer.
//
// Elements must be sorted by index: dim() reads only the last element and
// dot() merges two rows in a single ordered pass.
//
// A row either owns its buffer (allocated with new[] and released in the
// destructor) or merely references memory owned by someone else, as selected
// by the `own_data` constructor argument.
template <typename T>
class SparseRow {
    static_assert(std::is_same_v<T, fp32>, "SparseRow supports float only");

 public:
    // construct an SparseRow with memory allocated to hold `count` elements.
    // The elements are left uninitialized; fill them with set_at(). An empty
    // row (count == 0) allocates nothing and keeps data_ == nullptr.
    SparseRow(size_t count = 0)
        : data_(count ? new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) {
    }

    // wrap an existing buffer holding `count` packed elements. When `own_data`
    // is true the row will delete[] the buffer on destruction.
    SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) {
    }

    // copy constructor and copy assignment operator perform deep copy. The
    // copy always owns its buffer, even when `other` does not.
    SparseRow(const SparseRow<T>& other) : SparseRow(other.count_) {
        // An empty row has data_ == nullptr, and passing a null pointer to
        // memcpy is undefined behavior even when the byte count is zero.
        if (count_ != 0) {
            std::memcpy(data_, other.data_, data_byte_size());
        }
    }

    // move construction: start empty, then take over other's state.
    SparseRow(SparseRow<T>&& other) noexcept : SparseRow() {
        swap(*this, other);
    }

    SparseRow&
    operator=(const SparseRow<T>& other) {
        if (this != &other) {
            // copy-and-swap: the old buffer is released when tmp goes out of
            // scope.
            SparseRow<T> tmp(other);
            swap(*this, tmp);
        }
        return *this;
    }

    // move assignment exchanges state with `other`; the previous contents of
    // *this are released when `other` is destroyed.
    SparseRow&
    operator=(SparseRow<T>&& other) noexcept {
        swap(*this, other);
        return *this;
    }

    ~SparseRow() {
        if (own_data_ && data_ != nullptr) {
            delete[] data_;
            data_ = nullptr;
        }
    }

    // number of stored (index, value) elements.
    size_t
    size() const {
        return count_;
    }

    // bytes of the element array plus the size of this object itself.
    size_t
    memory_usage() const {
        return data_byte_size() + sizeof(*this);
    }

    // return the number of bytes used by the underlying data array.
    size_t
    data_byte_size() const {
        return count_ * element_size();
    }

    // raw access to the packed element array (nullptr for an empty row).
    void*
    data() {
        return data_;
    }

    const void*
    data() const {
        return data_;
    }

    // dim of a sparse vector is the max index + 1, or 0 for an empty vector.
    // Because elements are sorted by index, the max index is the last one.
    int64_t
    dim() const {
        if (count_ == 0) {
            return 0;
        }
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + count_ - 1;
        return elem->index + 1;
    }

    // return the i-th element as an (index, value) pair. No bounds checking.
    SparseIdVal<T>
    operator[](size_t i) const {
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + i;
        return {elem->index, elem->value};
    }

    // overwrite the i-th element. No bounds checking; the caller is
    // responsible for keeping the elements sorted by index.
    void
    set_at(size_t i, table_t index, T value) {
        auto* elem = reinterpret_cast<ElementProxy*>(data_) + i;
        elem->index = index;
        elem->value = value;
    }

    // inner product with `other`: a two-pointer merge over both sorted
    // element arrays that multiplies the values of matching indices.
    float
    dot(const SparseRow<T>& other) const {
        float product_sum = 0.0f;
        size_t i = 0;
        size_t j = 0;
        // TODO: improve with _mm_cmpistrm or the AVX512 alternative.
        while (i < count_ && j < other.count_) {
            auto* left = reinterpret_cast<const ElementProxy*>(data_) + i;
            auto* right = reinterpret_cast<const ElementProxy*>(other.data_) + j;

            if (left->index < right->index) {
                ++i;
            } else if (left->index > right->index) {
                ++j;
            } else {
                product_sum += left->value * right->value;
                ++i;
                ++j;
            }
        }
        return product_sum;
    }

    // member-wise swap; only exchanges a pointer, a size and a flag, so it
    // cannot throw (the noexcept move operations above rely on that).
    friend void
    swap(SparseRow<T>& left, SparseRow<T>& right) noexcept {
        using std::swap;
        swap(left.count_, right.count_);
        swap(left.data_, right.data_);
        swap(left.own_data_, right.own_data_);
    }

    // size in bytes of one packed (index, value) element.
    static inline size_t
    element_size() {
        return sizeof(table_t) + sizeof(T);
    }

 private:
    // ElementProxy is used to access elements in the data_ array and should
    // never be actually constructed. It is packed so that its layout matches
    // element_size() exactly, with no padding between index and value.
    struct __attribute__((packed)) ElementProxy {
        table_t index;
        T value;
        ElementProxy() = delete;
        ElementProxy(const ElementProxy&) = delete;
    };
    // data_ must be sorted by column id. use raw pointer for easy mmap and zero
    // copy.
    uint8_t* data_;
    size_t count_;
    bool own_data_;
};

然后索引具体是在InvertedIndex 类里, 对应sparse_inverted_index.h 文件,首先看定义的一些private 字段。

    // The rows exactly as they were added: Add() appends to it and Save()
    // serializes it.
    std::vector<SparseRow<T>> raw_data_;
    // Reader/writer lock: Add() and Load() lock it exclusively, Save() shared.
    mutable std::shared_mutex mu_;

    // Inverted lookup table: dimension id -> list of (row id, value) entries,
    // one per indexed row that kept a value in that dimension.
    std::unordered_map<table_t, std::vector<SparseIdVal<T>>> inverted_lut_;
    // When true, add_row_to_index() also maintains max_in_dim_, which
    // search_wand() reads to bound scores.
    bool use_wand_ = false;
    // If we want to drop small values during build, we must first train the
    // index with all the data to compute value_threshold_.
    bool drop_during_build_ = false;
    // when drop_during_build_ is true, any value smaller than value_threshold_
    // will not be added to inverted_lut_. value_threshold_ is set to the
    // drop_ratio_build-th percentile of all absolute values in the index.
    T value_threshold_ = 0.0f;
    // Per dimension, the largest value indexed so far (starting from 0); only
    // populated when use_wand_ is true.
    std::unordered_map<table_t, T> max_in_dim_;
    // Largest `dim` argument passed to Add() so far; Load() restores it from
    // the serialized header.
    size_t max_dim_ = 0;
  • raw_data_ 是原始的数据
  • inverted_lut_ 可以理解为一个倒排表
  • use_wand_ 用于控制查询时,是否使用WAND算法,WAND算法是经典的查询优化算法,可以通过类似跳表的方式跳过一些数据,减少计算量,提升查询效率
  • max_in_dim_ 是为wand服务的

索引构建流程

构建,主要是对外提供一个Add数据的方法:

    // Appends `rows` sparse rows starting at `data` to the index.
    //
    // The new rows get consecutive ids continuing after the rows already
    // stored. `dim` raises max_dim_ when it is larger than the current value.
    // Returns Status::invalid_args (and changes nothing) when the index already
    // holds rows and was built with value dropping enabled; otherwise
    // Status::success.
    Status
    Add(const SparseRow<T>* data, size_t rows, int64_t dim) {
        std::unique_lock<std::shared_mutex> lock(mu_);
        const auto first_new_id = n_rows_internal();
        const bool already_has_rows = first_new_id > 0;
        if (already_has_rows && drop_during_build_) {
            LOG_KNOWHERE_ERROR_ << "Not allowed to add data to a built index with drop_ratio_build > 0.";
            return Status::invalid_args;
        }
        max_dim_ = std::max(max_dim_, static_cast<size_t>(dim));

        // Keep the raw rows first, then feed each one into the inverted index.
        raw_data_.insert(raw_data_.end(), data, data + rows);
        for (size_t offset = 0; offset < rows; ++offset) {
            add_row_to_index(data[offset], first_new_id + offset);
        }
        return Status::success;
    }

这里会更新数据的max_dim,数据追加到raw_data_,然后add_row_to_index,将新的doc放入inverted_lut_, 并更新max_in_dim_,用于记录最大值,方便wand查询时跳过计算。

    // Registers every kept element of `row` in the inverted index under row id
    // `id`. When drop_during_build_ is set, elements whose absolute value is
    // below value_threshold_ are skipped. With use_wand_ enabled, the
    // per-dimension maximum in max_in_dim_ is updated as well.
    inline void
    add_row_to_index(const SparseRow<T>& row, table_t id) {
        const size_t n_elems = row.size();
        for (size_t pos = 0; pos < n_elems; ++pos) {
            auto [idx, val] = row[pos];
            // Values close enough to zero contribute little to the total IP
            // score, so they can be left out of the index.
            const bool negligible = drop_during_build_ && fabs(val) < value_threshold_;
            if (negligible) {
                continue;
            }
            // operator[] creates the posting list on first use of a dimension.
            inverted_lut_[idx].emplace_back(id, val);
            if (use_wand_) {
                // A dimension seen for the first time starts from a
                // value-initialized (zero) maximum.
                auto& dim_max = max_in_dim_[idx];
                dim_max = std::max(dim_max, val);
            }
        }
    }

索引保存与load

保存时,是自定义的二进制文件:

    Status
    Save(MemoryIOWriter& writer) {
        /**
         * Serializes the raw rows of the index to `writer` while holding mu_
         * as a shared (reader) lock. Always returns Status::success.
         *
         * zero copy is not yet implemented, now serializing in a zero copy
         * compatible way while still copying during deserialization.
         *
         * Layout, in the order written below:
         *
         * 1. rows: the row count, negated when use_wand_ is false, so the
         *    sign indicates whether to use wand (Load() reads this field as
         *    int64_t)
         * 2. cols: the value of n_cols_internal()
         * 3. T value_threshold_
         * 4. for each row:
         *     1. size_t len, the number of elements in the row
         *     2. for each non-zero value (nothing is written when len is 0):
         *        1. table_t idx
         *        2. T val
         *     With zero copy deserialization, each SparseRow object should
         *     reference(not owning) the memory address of the first element.
         *
         * inverted_lut_ and max_in_dim_ not serialized, they will be
         * constructed dynamically during deserialization.
         *
         * Data are densely packed in serialized bytes and no padding is added.
         */
        std::shared_lock<std::shared_mutex> lock(mu_);
        // NOTE(review): with zero rows the product is 0 and the sign (the
        // use_wand_ flag) cannot be recovered on load.
        writeBinaryPOD(writer, n_rows_internal() * (use_wand_ ? 1 : -1));
        writeBinaryPOD(writer, n_cols_internal());
        writeBinaryPOD(writer, value_threshold_);
        for (size_t i = 0; i < n_rows_internal(); ++i) {
            auto& row = raw_data_[i];
            writeBinaryPOD(writer, row.size());
            if (row.size() == 0) {
                // An empty row is fully described by its zero length.
                continue;
            }
            // The packed element array is written verbatim.
            writer.write(row.data(), row.size() * SparseRow<T>::element_size());
        }
        return Status::success;
    }

索引文件格式:

    1. rows:总记录数,通过±符号来区分是否 use wand(Load 时按 int64_t 读取)
    2. cols:列数(Load 时读入 max_dim_)
    3. T value_threshold:构建索引时用于丢弃小权重的阈值
    4. for each row:
        1. size_t len:该行非零元素的个数
        2. for each non-zero value:
            1. table_t idx:term的id编号
            2. T val:term的权重
    

注意,这里inverted_lut_倒排表是没有存储的,是在加载的时候重建,所以load的过程,就是一个逆过程:

Status
    Load(MemoryIOReader& reader) {
        // Rebuilds the index from the stream written by Save(): reads the
        // header (row count, max_dim_, value_threshold_), then each row, and
        // re-indexes every non-empty row because inverted_lut_ and max_in_dim_
        // are not serialized. Holds mu_ exclusively. Always returns
        // Status::success.
        std::unique_lock<std::shared_mutex> lock(mu_);
        int64_t rows;
        readBinaryPOD(reader, rows);
        // The sign of the stored row count carries the WAND flag; a stored
        // count of 0 therefore always yields use_wand_ == false.
        use_wand_ = rows > 0;
        rows = std::abs(rows);
        readBinaryPOD(reader, max_dim_);
        readBinaryPOD(reader, value_threshold_);

        raw_data_.reserve(rows);

        // NOTE(review): raw_data_[i] below is only the row just appended if
        // raw_data_ was empty on entry, i.e. Load() assumes a fresh index.
        for (int64_t i = 0; i < rows; ++i) {
            size_t count;
            readBinaryPOD(reader, count);
            // Allocate an owning row with room for `count` elements.
            raw_data_.emplace_back(count);
            if (count == 0) {
                // Empty rows have no element bytes in the stream.
                continue;
            }
            reader.read(raw_data_[i].data(), count * SparseRow<T>::element_size());
            add_row_to_index(raw_data_[i], i);
        }

        return Status::success;
    }

我们来回顾,compute_lexical_matching_score其实就是计算共同term的weight score相乘,然后加起来,所以可以想象下,暴力检索大概就是把所有term对应的doc取并集,然后计算lexical_matching_score,最后取topk。

我们来看milvus的实现,先看暴力检索:

    // Brute-force top-k search: scores every row against q_vec and offers each
    // candidate to `heap`, whose capacity determines k.
    // Query filtering (values below q_threshold, dimensions >= n_cols()) is
    // done by compute_all_distances. A row is skipped when `bitset` is
    // non-empty and has the row's bit set, or when its score is exactly 0.
    // TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
    void
    search_brute_force(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
        auto scores = compute_all_distances(q_vec, q_threshold);
        for (size_t row = 0; row < n_rows_internal(); ++row) {
            const bool filtered_out = !bitset.empty() && bitset.test(row);
            if (filtered_out || scores[row] == 0) {
                continue;
            }
            heap.push(row, scores[row]);
        }
    }

    // Computes the inner-product score of q_vec against every indexed row.
    //
    // Returns a vector with one score per row (index == row id); rows sharing
    // no kept dimension with the query stay at 0. A query element is ignored
    // when its absolute value is below q_threshold, when its dimension is
    // >= n_cols_internal(), or when no row has that dimension. The absolute
    // value is used so that negative query weights are filtered the same way
    // as in search_wand.
    std::vector<float>
    compute_all_distances(const SparseRow<T>& q_vec, T q_threshold) const {
        std::vector<float> scores(n_rows_internal(), 0.0f);
        for (size_t q_pos = 0; q_pos < q_vec.size(); ++q_pos) {
            auto [dim_id, q_val] = q_vec[q_pos];
            if (std::abs(q_val) < q_threshold || dim_id >= n_cols_internal()) {
                continue;
            }
            auto lut_it = inverted_lut_.find(dim_id);
            if (lut_it == inverted_lut_.end()) {
                continue;
            }
            // TODO: improve with SIMD
            // Each posting adds this dimension's contribution to its row.
            for (const auto& [row_id, row_val] : lut_it->second) {
                scores[row_id] += q_val * static_cast<float>(row_val);
            }
        }
        return scores;
    }
  • 核心在compute_all_distances里,先通过q_vec得到每一个term id,然后从inverted_lut_里找到term对应的doc list,然后计算score,相同doc id的score累加
  • 最后用MaxMinHeap堆,来取topk

暴力检索能保证结果的精确性,但是效率比较低。我们来看使用wand优化的检索:

// any value in q_vec that is smaller than q_threshold will be ignored.
    // Top-k search using WAND-style pruning. One cursor is created per usable
    // query element; the cursors are kept ordered by the row id they currently
    // point at, and per-cursor score upper bounds decide which row ids are
    // worth scoring at all. Candidates are pushed into `heap`.
    // NOTE(review): Cursor is defined elsewhere; remarks about cur_vec_id(),
    // seek(), next(), is_end() and max_score() describe how they are used
    // here, not their implementation.
    void
    search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
        auto q_dim = q_vec.size();
        std::vector<std::shared_ptr<Cursor<std::vector<SparseIdVal<T>>>>> cursors(q_dim);
        auto valid_q_dim = 0;
        // Build the cursors over the posting lists of the query's dimensions.
        for (size_t i = 0; i < q_dim; ++i) {
            // idx is the dimension (term) id, val the query weight for it.
            auto [idx, val] = q_vec[i];
            if (std::abs(val) < q_threshold || idx >= n_cols_internal()) {
                continue;
            }
            auto lut_it = inverted_lut_.find(idx);
            if (lut_it == inverted_lut_.end()) {
                continue;
            }
            auto& lut = lut_it->second;
            // max_in_dim_[idx] * val is handed to the cursor; presumably it is
            // what max_score() returns below, i.e. the largest contribution
            // this dimension can make to any row's score.
            cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<SparseIdVal<T>>>>(
                lut, n_rows_internal(), max_in_dim_.find(idx)->second * val, val, bitset);
        }
        if (valid_q_dim == 0) {
            return;
        }
        // Drop the unused tail slots left by skipped query elements.
        cursors.resize(valid_q_dim);
        auto sort_cursors = [&cursors] {
            std::sort(cursors.begin(), cursors.end(),
                      [](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
        };
        sort_cursors();
        // True while the heap is not full, or when x beats the value at the
        // top of the heap.
        auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
        while (true) {
            // Running sum of the upper bounds of cursors[0..pivot].
            float upper_bound = 0;
            // pivot: index of the first cursor at which the accumulated upper
            // bound passes the threshold.
            size_t pivot;
            bool found_pivot = false;
            for (pivot = 0; pivot < cursors.size(); ++pivot) {
                // Stop at the first exhausted cursor.
                if (cursors[pivot]->is_end()) {
                    break;
                }
                upper_bound += cursors[pivot]->max_score();
                if (score_above_threshold(upper_bound)) {
                    found_pivot = true;
                    break;
                }
            }
            if (!found_pivot) {
                // No remaining row can pass the threshold: the search is done.
                break;
            }
            // The row id the pivot cursor currently points at.
            table_t pivot_id = cursors[pivot]->cur_vec_id();
            // If the first cursor already points at pivot_id, every cursor on
            // pivot_id sits at the front, so the row can be scored directly.
            if (pivot_id == cursors[0]->cur_vec_id()) {
                float score = 0;
                // Accumulate the contribution of each leading cursor that is
                // positioned on pivot_id.
                for (auto& cursor : cursors) {
                    if (cursor->cur_vec_id() != pivot_id) {
                        break;
                    }
                    score += cursor->cur_distance() * cursor->q_value();
                    // Advance this cursor to its next entry.
                    cursor->next();
                }
                // Offer the fully scored row to the heap.
                heap.push(pivot_id, score);
                // Re-sort so the smallest current row id is in front again.
                sort_cursors();
            } else {
                // The first cursor is not on pivot_id, hence pivot >= 1.
                // Walk back from the pivot past the cursors that are already
                // on pivot_id; the loop stops at index 0 at the latest, since
                // cursors[0] is known to differ from pivot_id.
                size_t next_list = pivot;
                for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
                }
                // cursors[next_list] is not on pivot_id. Seek it to pivot_id,
                // which presumably leaves it at a row id >= pivot_id and skips
                // the entries in between.
                cursors[next_list]->seek(pivot_id);
                // Restore the ordering by moving the advanced cursor to the
                // right.
                for (size_t i = next_list + 1; i < cursors.size(); ++i) {
                    // Stop as soon as this pair is in order.
                    if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
                        break;
                    }
                    // Otherwise swap, so the cursor with the smaller row id
                    // moves toward the front.
                    std::swap(cursors[i], cursors[i - 1]);
                }
            }
        }
    }
  • 首先是倒排链取出来放入cursors,然后对cursors按照vec_id排序,将vec_id较小的排到倒排链的首位
  • 通过score_above_threshold,遍历cursors找符合条件的cursor 索引号pivot,这里通过堆未满,或者新的score > 堆顶的score来判断,可以跳过一些score小的
  • 然后找到pivot cursor对应的pivot_id,也就是doc id,然后判断第一个倒排链的cur_vec_id 是否等于pivot_id:
    • 如果等于,就可以遍历倒排链,计算pivot_id的score,然后放入小顶堆中排序,然后重排倒排链
    • 如果不等于,那么就需要想办法将cur_vec_id == pivot_id的往前放,同时跳过倒排链中vec_id < pivot_id的数据(剪枝)

用golang实现轻量级sparse vec检索

用类似milvus的方法,我们简单实现一个golang版本的

package main

import (
	"container/heap"
	"encoding/binary"
	"fmt"
	"io"
	"math/rand"
	"os"
	"sort"
	"time"
)

// Cursor is a forward-only iterator over one term's posting list.
type Cursor struct {
	docIDs     []int32   // document IDs of the posting list
	weights    []float64 // weights[i] belongs to docIDs[i]
	maxScore   float64   // score bound supplied by the creator of the cursor
	termWeight float64   // weight of the term in the query
	currentIdx int       // position of the current posting
}

// NewCursor returns a cursor positioned on the first posting.
func NewCursor(docIDs []int32, weights []float64, maxScore float64, weight float64) *Cursor {
	cursor := new(Cursor)
	cursor.docIDs = docIDs
	cursor.weights = weights
	cursor.maxScore = maxScore
	cursor.termWeight = weight
	return cursor
}

// Next moves the cursor one posting forward.
func (c *Cursor) Next() {
	c.currentIdx += 1
}

// Seek advances the cursor to the first posting whose document ID is not
// smaller than docId, or to the end when there is none. It never moves
// backwards.
func (c *Cursor) Seek(docId int32) {
	for !c.IsEnd() && c.CurrentDocID() < docId {
		c.Next()
	}
}

// IsEnd reports whether the cursor has moved past the last posting.
func (c *Cursor) IsEnd() bool {
	return len(c.docIDs) <= c.currentIdx
}

// CurrentDocID returns the document ID of the current posting. It must not be
// called once IsEnd reports true.
func (c *Cursor) CurrentDocID() int32 {
	pos := c.currentIdx
	return c.docIDs[pos]
}

// CurrentDocWeight returns the weight stored for the current posting. It must
// not be called once IsEnd reports true.
func (c *Cursor) CurrentDocWeight() float64 {
	pos := c.currentIdx
	return c.weights[pos]
}

// DocVectors maps a docID to that document's sparse vector (termID -> weight).
type DocVectors map[int32]map[int32]float64

// InvertedIndex maps a termID to the list of docIDs containing that term.
// Add appends docIDs in insertion order, so a list is sorted only when
// documents are added in ascending docID order.
type InvertedIndex map[int32][]int32

// TermMaxScores keeps, for each termID, the largest weight seen for that term
// in any added document.
type TermMaxScores map[int32]float64

// SparseIndex is an in-memory sparse-vector index: the raw document vectors,
// an inverted index over their terms, and per-term maximum weights.
type SparseIndex struct {
	docVectors    DocVectors    // docID -> sparse vector
	invertedIndex InvertedIndex // termID -> docIDs containing the term
	termMaxScores TermMaxScores // termID -> largest weight seen for the term
	dim           int32         // largest termID seen so far (0 when none is positive)
}

// NewSparseIndex returns an empty index whose maps are allocated and ready
// for Add; dim starts at its zero value.
func NewSparseIndex() *SparseIndex {
	idx := new(SparseIndex)
	idx.docVectors = DocVectors{}
	idx.invertedIndex = InvertedIndex{}
	idx.termMaxScores = TermMaxScores{}
	return idx
}

// Add stores vec as the vector of document docID and indexes its terms.
//
// Each posting list is kept sorted by docID whatever order documents arrive
// in: the docID is inserted at its sorted position, which is a plain append
// when documents are added in ascending order. Adding a docID that is already
// present replaces its vector without duplicating postings.
//
// Limitations: vec is stored by reference, so the caller must not modify it
// afterwards. When a document is re-added, postings for terms that only the
// old vector contained are not removed, and termMaxScores is never lowered.
func (index *SparseIndex) Add(docID int32, vec map[int32]float64) {
	index.docVectors[docID] = vec

	for termID, score := range vec {
		postings := index.invertedIndex[termID]
		// First position whose docID is not smaller than the new one.
		pos := sort.Search(len(postings), func(i int) bool {
			return postings[i] >= docID
		})
		if pos == len(postings) || postings[pos] != docID {
			// Grow by one, shift the tail right and drop docID into the gap.
			postings = append(postings, 0)
			copy(postings[pos+1:], postings[pos:])
			postings[pos] = docID
			index.invertedIndex[termID] = postings
		}

		// Track max score for each term
		if maxScore, ok := index.termMaxScores[termID]; !ok || score > maxScore {
			index.termMaxScores[termID] = score
		}
		// dim records the largest termID seen so far.
		if termID > index.dim {
			index.dim = termID
		}
	}
}

// Save writes the index to filename, creating or truncating the file.
//
// Layout (all values little-endian):
//
//	int32 dim
//	per document: int32 docID, int32 termCount,
//	              then termCount x (int32 termID, float64 score)
//
// Only the document vectors are stored; Load rebuilds the inverted index and
// the per-term maxima from them. Documents are written in ascending docID
// order, and terms in ascending termID order, so the output is deterministic
// and Load re-adds documents in docID order.
//
// The first write or close error is returned.
func (index *SparseIndex) Save(filename string) (err error) {
	file, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean buffered data never reached the disk, so
		// report it unless an earlier error is already being returned.
		if closeErr := file.Close(); err == nil {
			err = closeErr
		}
	}()

	write := func(v interface{}) error {
		return binary.Write(file, binary.LittleEndian, v)
	}

	// Write the dimension
	if err = write(index.dim); err != nil {
		return err
	}

	// Map iteration order is random; sort the IDs to get a stable order.
	docIDs := make([]int32, 0, len(index.docVectors))
	for docID := range index.docVectors {
		docIDs = append(docIDs, docID)
	}
	sort.Slice(docIDs, func(i, j int) bool { return docIDs[i] < docIDs[j] })

	// Write each document vector
	for _, docID := range docIDs {
		vec := index.docVectors[docID]
		if err = write(docID); err != nil {
			return err
		}
		if err = write(int32(len(vec))); err != nil {
			return err
		}

		termIDs := make([]int32, 0, len(vec))
		for termID := range vec {
			termIDs = append(termIDs, termID)
		}
		sort.Slice(termIDs, func(i, j int) bool { return termIDs[i] < termIDs[j] })

		for _, termID := range termIDs {
			if err = write(termID); err != nil {
				return err
			}
			if err = write(vec[termID]); err != nil {
				return err
			}
		}
	}

	return nil
}

// Load reads an index file written by Save and adds every stored document to
// index through Add, which rebuilds the inverted index and per-term maxima.
//
// Expected layout (all values little-endian): int32 dim, then per document
// int32 docID, int32 termCount and termCount x (int32 termID, float64 score).
//
// Documents loaded before an error occurs remain in the index. A file that
// ends in the middle of the header or of a document record yields
// io.ErrUnexpectedEOF; a negative term count yields a descriptive error.
func (index *SparseIndex) Load(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()

	// readField reads one fixed-size value that must be present: running out
	// of data here means the file is truncated, not cleanly finished.
	readField := func(v interface{}) error {
		err := binary.Read(file, binary.LittleEndian, v)
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	var dim int32
	if err := readField(&dim); err != nil {
		return err
	}
	// Never lower the dimension of an index that already holds data.
	if dim > index.dim {
		index.dim = dim
	}

	for {
		var docID int32
		err := binary.Read(file, binary.LittleEndian, &docID)
		if err == io.EOF {
			break // Clean end of file: no further document record.
		} else if err != nil {
			return err // Some other error
		}

		var vecSize int32
		if err := readField(&vecSize); err != nil {
			return err
		}
		if vecSize < 0 {
			return fmt.Errorf("invalid term count %d for document %d", vecSize, docID)
		}
		vec := make(map[int32]float64)

		for i := int32(0); i < vecSize; i++ {
			var termID int32
			var score float64
			if err := readField(&termID); err != nil {
				return err
			}
			if err := readField(&score); err != nil {
				return err
			}
			vec[termID] = score
		}

		index.Add(docID, vec) // Rebuild the index
	}
	return nil
}

// bruteSearch scores every document that shares at least one term with
// queryVec and returns the IDs of the K highest-scoring ones.
//
// The returned IDs are sorted by ascending docID, not by score. Fewer than K
// IDs are returned when fewer documents match, and an empty slice when
// K <= 0.
func (index *SparseIndex) bruteSearch(queryVec map[int32]float64, K int) []int32 {
	if K <= 0 {
		// Without this guard the heap stays empty and peeking at its root
		// below would index out of range.
		return []int32{}
	}

	scores := computeAllDistances(queryVec, index)

	// Keep the K best documents in a min-heap: the root is the weakest of the
	// current top K and is evicted when a better document shows up.
	docHeap := &DocScoreHeap{}
	for docID, score := range scores {
		if docHeap.Len() < K {
			heap.Push(docHeap, &DocScore{docID, score})
		} else if (*docHeap)[0].score < score {
			heap.Pop(docHeap)
			heap.Push(docHeap, &DocScore{docID, score})
		}
	}

	topDocs := make([]int32, 0, docHeap.Len())
	for docHeap.Len() > 0 {
		el := heap.Pop(docHeap).(*DocScore)
		topDocs = append(topDocs, el.docID)
	}

	sort.Slice(topDocs, func(i, j int) bool {
		return topDocs[i] < topDocs[j]
	})

	return topDocs
}

// computeAllDistances returns, for every document sharing at least one term
// with queryVec, the sum over the shared terms of query weight times document
// weight. Documents without a shared term are absent from the result.
func computeAllDistances(queryVec map[int32]float64, index *SparseIndex) map[int32]float64 {
	totals := make(map[int32]float64)
	for termID, queryWeight := range queryVec {
		// Ranging over a missing posting list simply does nothing.
		for _, docID := range index.invertedIndex[termID] {
			docWeight, found := index.docVectors[docID][termID]
			if !found {
				continue
			}
			contribution := queryWeight * docWeight
			if running, seen := totals[docID]; seen {
				totals[docID] = running + contribution
			} else {
				totals[docID] = contribution
			}
		}
	}
	return totals
}

// WandSearch returns the IDs of up to K documents with the highest
// inner-product score against queryVec, using WAND-style pruning: one cursor
// per query term walks that term's posting list, and per-term score bounds
// are used to skip documents that cannot enter the current top K.
//
// The returned IDs are sorted by ascending docID, not by score. An empty
// slice is returned when K <= 0 or when no query term occurs in the index.
//
// Requirements: every posting list must be sorted by ascending docID.
// NOTE(review): the pruning bound termMaxScores[term]*termWeight is only a
// valid upper bound when weights are non-negative; confirm before using
// negative weights.
func (index *SparseIndex) WandSearch(queryVec map[int32]float64, K int) []int32 {
	if K <= 0 {
		// An empty heap can never be "full", and peeking at its root below
		// would index out of range.
		return []int32{}
	}

	docHeap := &DocScoreHeap{}

	// One cursor per query term that actually occurs in the index. Terms
	// without postings are skipped so the slice never holds nil cursors.
	postingLists := make([]*Cursor, 0, len(queryVec))
	for term, termWeight := range queryVec {
		postingList, exists := index.invertedIndex[term]
		if !exists || len(postingList) == 0 {
			continue
		}
		// weights[i] is the weight of this term in document postingList[i].
		weights := make([]float64, len(postingList))
		for i, docID := range postingList {
			weights[i] = index.docVectors[docID][term]
		}
		postingLists = append(postingLists,
			NewCursor(postingList, weights, index.termMaxScores[term]*termWeight, termWeight))
	}
	if len(postingLists) == 0 {
		return []int32{}
	}

	// precedes orders cursors by their current docID; exhausted cursors sort
	// after all live ones, so they collect at the tail of the slice.
	precedes := func(a, b *Cursor) bool {
		if a.IsEnd() {
			return false
		}
		if b.IsEnd() {
			return true
		}
		return a.CurrentDocID() < b.CurrentDocID()
	}

	sortPostings := func() {
		sort.Slice(postingLists, func(i, j int) bool {
			return precedes(postingLists[i], postingLists[j])
		})
	}

	sortPostings()

	// True while the heap is not full, or when value beats the weakest
	// document currently in the top K.
	scoreAboveThreshold := func(value float64) bool {
		return docHeap.Len() < K || (*docHeap)[0].score < value
	}

	for {
		// Accumulate per-term bounds in cursor order until a document at
		// that position could enter the top K; that cursor is the pivot.
		upperBound := 0.0
		pivot := -1
		for i, cursor := range postingLists {
			if cursor.IsEnd() {
				// Exhausted cursors sort last: nothing live lies beyond.
				break
			}
			upperBound += cursor.maxScore
			if scoreAboveThreshold(upperBound) {
				pivot = i
				break
			}
		}

		if pivot < 0 {
			// No remaining document can beat the current top K.
			break
		}

		// All cursors up to and including the pivot are live here.
		pivotID := postingLists[pivot].CurrentDocID()
		if pivotID == postingLists[0].CurrentDocID() {
			// Every cursor positioned on pivotID sits at the front, so the
			// document can be scored completely in one pass.
			score := 0.0
			for _, cursor := range postingLists {
				if cursor.IsEnd() || cursor.CurrentDocID() != pivotID {
					break
				}
				score += cursor.CurrentDocWeight() * cursor.termWeight
				// Move on to this term's next document.
				cursor.Next()
			}

			// Offer the scored document to the top-K heap.
			if docHeap.Len() < K {
				heap.Push(docHeap, &DocScore{pivotID, score})
			} else if (*docHeap)[0].score < score {
				heap.Pop(docHeap)
				heap.Push(docHeap, &DocScore{pivotID, score})
			}

			// Restore the ordering by current docID.
			sortPostings()
		} else {
			// The first cursor is behind pivotID, hence pivot >= 1. Walk back
			// from the pivot past the cursors already on pivotID; the loop
			// stops at index 0 at the latest because that cursor differs.
			nextList := pivot
			for postingLists[nextList].CurrentDocID() == pivotID {
				nextList--
			}
			// This cursor is behind pivotID: skip it forward to the first
			// document >= pivotID (or to the end of its list).
			postingLists[nextList].Seek(pivotID)

			// Only that cursor moved, so one bubbling pass to the right
			// restores the ordering.
			for i := nextList + 1; i < len(postingLists); i++ {
				if !precedes(postingLists[i], postingLists[i-1]) {
					break
				}
				postingLists[i], postingLists[i-1] = postingLists[i-1], postingLists[i]
			}
		}
	}

	topDocs := make([]int32, 0, docHeap.Len())
	for docHeap.Len() > 0 {
		el := heap.Pop(docHeap).(*DocScore)
		topDocs = append(topDocs, el.docID)
	}
	sort.Slice(topDocs, func(i, j int) bool {
		return topDocs[i] < topDocs[j]
	})

	return topDocs
}

// Helper structure to manage the priority queue for the top-K documents
type DocScore struct {
	docID int32
	score float64
}

type DocScoreHeap []*DocScore

func (h DocScoreHeap) Len() int           { return len(h) }
func (h DocScoreHeap) Less(i, j int) bool { return h[i].score < h[j].score }
func (h DocScoreHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *DocScoreHeap) Push(x interface{}) {
	*h = append(*h, x.(*DocScore))
}

func (h *DocScoreHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}

// main builds a small random index and runs one WAND query against it.
func main() {
	index := NewSparseIndex()

	// Seed the global generator from the clock so every run differs.
	rand.Seed(time.Now().UnixNano())
	// Add document vectors as needed
	for i := 1; i <= 1000; i++ {
		// Documents 1..1000 are added in ascending ID order; each gets
		// random weights for the same four terms.
		index.Add(int32(i), map[int32]float64{101: rand.Float64(),
			150: rand.Float64(),
			190: rand.Float64(),
			500: rand.Float64()})
	}
	// Optional round trip through the on-disk format (disabled):
	//index.Save("index.bin")
	//index.Load("index.bin")
	// Query with random weights on the same four terms and print the top 10.
	topDocs := index.WandSearch(map[int32]float64{101: rand.Float64(), 150: rand.Float64(), 190: rand.Float64(),
		500: rand.Float64()}, 10)
	fmt.Println("Top Docs:", topDocs)
}

  • 代码实现了索引的构建、保存和加载,检索方面实现了暴力检索和WAND检索
  • 注意,添加doc时,需要保障doc有序,实际应用中,docid可以引擎维护一个真实id到递增docid的映射
  • 代码中已经有注释,这里不再赘述,注意代码未充分调试,可能有bug
  • 代码实现倒排表全放到内存,效率高,但对内存要求高

sparse 检索整体类似传统的文本检索,因此传统的工程优化方法可以运用到sparse检索中,本文分析了milvus的实现,并实现了一个golang版本的sparse检索。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK