Background
LSM-KV 多年来一直是我们学校软工大二下高级数据结构这门课的 Project,具体内容就是自己手搓一个 LSM-tree。在 大量老学长老学姐在水源上吐槽软工任务量大且没用 今年课程改革后,老师也对这个 Project 进行了一些改造,不再要求我们从零手搓一个完整的 LSM-tree 了,同时结合当下大模型的热点,要求我们实现语义搜索的功能。
这个 Project 是分阶段发布的,每个阶段要求我们实现一部分功能,不得不说 ddl 虽然没有特别紧但也不怎么轻松。(dmk:以前给大家多一两周最后发现大家也是 ddl 前两三天才开始啊
LSM-tree 的具体原理网上有大量清晰的讲解,写之前我好好学了下。整个 Project 的代码我都放在我 GitHub 上的仓库了。
Phase 1
日志结构合并树(Log-Structured Merge-Tree,简称 LSM-tree) 是一种可以高性能执行大量写操作的数据结构,它于1996年在 Patrick O’Neil 等人的一篇论文1中被提出,是近年来存储学术会议中的热门话题。Google 的开源项目 LevelDB 和 Meta 的开源项目 RocksDB 都以 LSM-tree 作为核心数据结构。
阶段一是比较传统的部分,但是砍掉了大量的工作,我们需要实现的是:
skiplist类,作为 memtable 的数据结构;compaction()函数,当插入或删除时可能触发的合并操作;fetchString()函数,用于查询时从硬盘上的文件中读出相应的值。
skip list
跳表是非常经典的实现,构造函数需传入一个增长率p。跳表中定义了一个头结点和尾结点,不保存任何数据,仅用于索引(这是我喜欢的方式,操作起来真的方便很多)。
skip list node
跳表的节点类定义如下:
class slnode {
public:
uint64_t key;
std::string val;
TYPE type;
std::vector<slnode *> nxt;
slnode(uint64_t key, const std::string &val, TYPE type) {
this->key = key;
this->val = val;
this->type = type;
for (int i = 0; i < MAX_LEVEL; ++i)
nxt.push_back(nullptr);
}
};其中数组std::vector<slnode *> nxt用于保存该节点在每一层的后继结点。
search
跳表中实现起来最简单的操作,根据跳表的原理,在每一层搜索到key正好小于待查key的节点,判断它在该层的后继结点是否等于待查key,不符合跳到下一层继续搜索即可。如果到了底层的尾结点都还没搜索到,说明不存在。
insert
执行插入操作时,维护一个数组std::vector<slnode *> update,类似于搜索操作,只不过要把每层跳转的位置记录下来。然后再统一把update中每个节点的后继结点更新为插入节点即可。
这其中有一种特殊的情况,就是当update[0]的后继结点的key正好等于待插入节点的key,这时候就不需要新建节点,把update[0]->val更新为待插入节点的val即可。
delete
由于在 LSM-tree 中,删除操作并不是真的删掉了某个键值对,而是插入一个 “~DELETE~” 字符串,这一功能事实上已经在insert()中实现了,所以我在询问助教后就直接偷懒没实现del(),当然还包括了另一个没用到的lowerBound()(逃
scan
这个功能是取出跳表中所有key介于key1和key2之间的节点,其实和 search 没差多少,就是找到边界的两个节点,然后在底层中把它们两个之间的全部取出来就行。
reset
和链表一样,在底层把除了头尾节点的节点逐个删掉即可。
fetch string
之所以先写这部分,当然是因为它比较简单啦,都是 C 标准库的文件操作,直接让 AI 教我就行。
std::string KVStore::fetchString(std::string file, int startOffset, uint32_t len) {
// TODO here
FILE *fp = fopen(file.c_str(), "rb");
fseek(fp, startOffset, SEEK_SET);
fread(strBuf, 1, len, fp);
fclose(fp);
return std::string(strBuf, len);
}compaction
一开始做感觉似乎很繁琐,真的写起来发现其实也挺有意思。这个 Project 要求每个 sstable 的大小不超过 2MB,也就是在 memtable 满 2MB 的时候要把它整个转换成 sstable 并写入硬盘,这个时候就有可能会触发合并操作。硬盘是分层存储的,要求第 层有不超过 个 sstable,当 memtable 转换为 sstable 存入 Level 0 导致该层的 sstable 个数超出时,就要把该层的所有 sstable 和下一层key范围有相交的 sstable 合并后放入 Level 1。从 Level 1 开始,如果该层的 sstable 个数超出,只需考虑时间戳最小的几个即可。
用一个while循环判断是否在当前层触发合并操作:
while (sstableIndex[curLevel].size() > 1 << curLevel + 1) {
std::vector<sstablehead> targets; // 需要合并的所有sstable
uint64_t start = INF, end = 0; // 当前层要合并的sstable覆盖的区间
int compactionNum; // 当前层要合并的sstable数
// ...
}然后计算当前层要合并的 sstable 数,而 sstablehead 本来就是存在内存里的,只需取出sstableIndex[curLevel]的前compactionNum个加入targets并更新start和end即可,因为时间戳小的本身就在前面。
接下来是判断下一层,如果不存在下一层,那么就需要新建一级目录;否则就要在下一层里查找key范围相交的 sstable 也加入targets中。
然后就是处理涉及到的所有键值对了,维护一棵红黑树(因为需要频繁地搜索和插入,选了一个综合速度比较快的)用于存储拿出来的所有键值对:
// 存储需要合并的所有数据,第一个参数是key,第二个是val,第三个是time
std::map<uint64_t, std::pair<std::string, uint64_t>> datas;对每一个键值对,插入前先检查datas里有没有,没有的话直接插入,有的话保留时间戳大的,包括 “~DELETE~” 标记。
处理完全部数据后再把targets里对应的 sstable 逐一删除。
最后就是在下一层新建 sstable 了,把datas里的数据逐一取出来后按顺序放入 sstable,每满一个 sstable 就往硬盘里写一次,最后还要把剩下不满一个 sstable 的键值对也写入硬盘。
让while循环继续处理下一层。
Phase 2
这一阶段要把 LSM-tree 升级成 Smart LSM-tree,也就是支持语义搜索,比如我搜索 “fruit”,能够查到存过的 “apple”、“banana”、“orange” 等,其实这一阶段实现的功能已经和传统的 LSM-tree 没啥关系了。
warm up
根据指示,需要先在当前目录下安装一个子模块:
git submodule add https://gitee.com/ShadowNearby/llama.cpp.git third_party/llama.cpp这一部分是为了让我们先熟悉熟悉 llama.cpp 的使用,在接下来的部分中将使用它来运行老师提供的模型,对每个插入的val生成嵌入式向量,通过计算并比较余弦相似度的方式来查找语义最接近的结果。
implementation
在kvstore_api.h中定义了一个新的接口:
/**
* Search the nearest k key-value pairs to the given query.
* The result should be sorted by cosine similarity in decending order.
*/
virtual std::vector<std::pair<std::uint64_t, std::string>> search_knn(std::string query, int k) = 0;用于查找 k 个语义最接近的val。
embedding
由于生成向量的耗时较长,这部分要求把向量全部保存在内存中,不需要考虑向量的持久化。需要在以下两个类中各加入一个新的成员。
class slnode {
std::vector<float> vec; // embedding for value
};
class KVStore {
std::vector<std::vector<std::vector<float>>> vecs[15]; // embedding for each value
};然后在插入到 memtable 或新建 sstable 时对每个val调用一次embedding_single()即可。频繁加载模型导致性能较低,助教提出可以参考模型的特性,一次处理多个请求,但是我懒得搞了,而且我感觉速度也还行
k nearest neighbor search
根据指示,在本阶段只需要按顺序逐一查找,找到相似度最高的 k 个即可。
我的做法是维护一个小顶堆(大小不超过 k),然后逐一取出向量并计算它们与query对应的向量的余弦相似度并压入堆中,同时存入堆中的数据还有对应的val的索引,最后只需要按索引从硬盘上取出堆中的 k 个元素的val即可,不用取出所有val。
一开始我只记得从所有的 sstable 里去搜索了,结果跑出来正确率是100%,也是逆天。后来我想起来还要在 memtable 里去搜,又把这部分给补上了,好在正确率没变。
Phase 3
HNSW 是一种多层图索引结构,是 NSW 算法的改进版本。NSW 的核心思想是将数据库中的向量与接近的向量相连,形成所谓的 “Small World”,然后在这个连通图上从某个起始节点开始,通过选择逐渐靠近目标节点的边进行导航。在此基础上,HNSW 借鉴了跳表的思想,采用分级存储特征向量的方式。在高层级存储较少的向量,这使得导航过程能够在高层级快速地靠近目标节点,从而提高查询的效率。HNSW 被 Meta 运用到了他们的 Faiss 库中。
在这一阶段,我们需要构建一个 HNSW,把上一阶段生成的向量都插入 HNSW 中,再使用 HNSW 进行查询,和暴力搜索比较正确率和效率。
declaration
为了模块化实现 HNSW,我新建了hnsw.cpp和hnsw.h两个文件,其中类的声明如下:
#ifndef HNSW_H
#define HNSW_H
#include <vector>
class HNSWNode {
public:
int level; // 当前节点的层级
uint64_t key;
std::string val;
std::vector<float> vec;
// 每层中的邻居节点和与该邻居节点的距离(注意:这里的距离是余弦相似度,余弦相似度越大,距离越小)
std::vector<std::vector<std::pair<float, HNSWNode *>>> neighbors;
HNSWNode(int level, uint64_t key, const std::string &val, const std::vector<float> &vec)
: level(level), key(key), val(val), vec(vec) {
neighbors.resize(level + 1);
}
bool operator<(const HNSWNode &other) const {
return true;
}
bool operator>(const HNSWNode &other) const {
return true;
}
};
class HNSW {
private:
int M;
int M_max;
int efConstruction;
int m_L;
HNSWNode *entry_point;
public:
HNSW() {}
HNSW(int M, int M_max, int efConstruction, int m_L)
: M(M), M_max(M_max), efConstruction(efConstruction), m_L(m_L) {
entry_point = nullptr;
}
int rand_level();
void insert(uint64_t key, const std::string &val, const std::vector<float> &vec);
std::vector<std::pair<uint64_t, std::string>> search(std::string query, int k);
};
#endif // HNSW_H这里面有个小坑,正常来说距离越大是越远的,但由于这里比较用的是余弦相似度,所以越大距离反而越近,当时写的时候还弄错了几个大于号小于号。HNSWNode类中运算符重载应该是用不到,因为不太可能算出两个完全相等的余弦相似度,主要是以防万一。
implementation
random level
HNSW 的论文和网上的博文都推荐在插入新节点时随机生成的层数 ,其中 是图的最大层数,有同学推荐直接使用 上的均匀分布,正确率一下就上来了,但在我的实现里测出来还是几何分布的效果会好得多。
insert
插入操作实现起来并不困难,在 及以上层级,每层都从入口点导航到该层离待插入节点最近的节点;从 到底层,每层在导航过程中都维护一个堆,保存不超过 个邻居,最后从中选择不超过 个进行连接。需要注意的是,如果其中选择的某个邻居已经连接了 个节点,需要判断它到最远邻居的距离是否比它到待插入节点的距离远,如果是,则需要替换;如果不是,则使待插入节点从堆中继续选择下一个节点。替换节点时需要维持连接的一致性,得把两个节点的邻居互相删了,一开始我忘了这回事,只删了一边,补上后正确率直接提高了一个百分点左右。
search
我一开始实现的搜索算法是按照文档里给出来的那样,和插入操作类似,算法如下:
std::vector<std::pair<uint64_t, std::string>> HNSW::search(std::string query, int k) {
std::vector<float> query_vec = embedding_single(query);
// 自顶层向底层逐层搜索,导航到离待搜索节点最近的节点
HNSWNode *cur = entry_point;
for (int i = entry_point->level; i >= 1; --i) {
float max_sim = common_embd_similarity_cos(query_vec.data(), cur->vec.data(), query_vec.size());
while (true) {
bool flag = false;
for (auto neighbor : cur->neighbors[i]) {
float sim = common_embd_similarity_cos(query_vec.data(), neighbor.second->vec.data(), query_vec.size());
if (sim > max_sim) {
max_sim = sim;
cur = neighbor.second;
flag = true;
}
}
if (!flag) break;
}
}
// 在最底层进行 k 近邻搜索
std::priority_queue<
std::pair<float, HNSWNode *>,
std::vector<std::pair<float, HNSWNode *>>,
std::greater<std::pair<float, HNSWNode *>>> candidates; // 用一个小顶堆来存储候选节点
float max_sim = common_embd_similarity_cos(query_vec.data(), cur->vec.data(), query_vec.size());
candidates.push(std::make_pair(max_sim, cur));
while (true) {
bool flag = false;
for (auto node : cur->neighbors[0]) {
float sim = common_embd_similarity_cos(query_vec.data(), node.second->vec.data(), query_vec.size());
// 把搜索过的邻居都放入候选
if (candidates.size() < k) {
candidates.push(std::make_pair(sim, node.second));
} else {
if (sim > candidates.top().first) {
candidates.pop();
candidates.push(std::make_pair(sim, node.second));
}
}
if (sim > max_sim) {
max_sim = sim;
cur = node.second;
flag = true;
}
}
if (!flag) break;
}
std::vector<std::pair<uint64_t, std::string>> ans;
while (!candidates.empty()) {
auto cur = candidates.top();
candidates.pop();
ans.push_back(std::make_pair(cur.second->key, cur.second->val));
}
std::reverse(ans.begin(), ans.end());
return ans;
}结果测出来的正确率十分逆天,只有百分之十几,当时心都死了。
于是只好去问 Claude Sonnet 3.5,它给了我一个下降时维持多个候选路径的办法,正确率瞬间就升到了百分之八九十。
test
测试也真是给我测破防了,不管哪种搜索方法,测出来的时间都比 Phase 2 的暴力搜索远远慢得多,这一定是因为样本量太小了,是的一定是这样的。 暴力搜索测出来的平均时间大概在 0.1 ms 左右,而 HNSW 搜索的时间竟然到了 3 ms 左右,这我写了两百多行实现了个啥啊(不管那么多了最后报告就如实写完交了)。如果增大 或 ,正确率几乎能接近100%,但是搜索时间甚至到了 10+ ms。其它参数的分析都在这篇抽象至极的报告里了。
反正就是挺逆天的,助教开了个 Project 问题汇总文档,有同学提到提高正确率的两个措施,一个是关掉编译优化,一个是上面提到的层数生成。关编译优化也对我的正确率没有任何影响,只是让程序的运行速度显著慢了很多。以及助教重写了一份embedding.cc,以提高向量的生成效率,然而我直接用却会报错。因为我们是第一届写这个新 Project,总之就是问题一大堆,但确实也能学到一些些有趣的东西。
Phase 4
原先的 LSM-tree 持久化的部分只有传统的分层存储的键值对的部分,在本阶段中,需要对已经做好 embedding 的向量也保存在磁盘中。除此之外,还要把构建好的 HNSW 索引也保存在磁盘中,下次启动时能够直接从磁盘上载入 HNSW 索引,省去了重新构建的步骤。
vector persistence
按照要求,向量在磁盘上应统一保存在一个文件内,文件头是一个8字节的dimension(在本 Project 中固定为768)。接下来由重复的数据块构成:8字节的key和768 × 4字节的vec。向量不会被删除,只会不断地追加到文件尾,读取时只需要从后往前找到第一个对应的key,那么此处的vec就是对应的向量。
我的实现方式是,在skiplist类中新增一个方法putEmbeddingFile(),在每次 memtable 被写入磁盘中时同时调用一次,这样就能存下所有的向量了。
在系统重启载入数据时,对每一个key都从文件末遍历,找到对应的向量并载入到内存中。
HNSW delete
本阶段还需要支持 HNSW 索引的删除操作,采取的方式是 Lazy Delete,即维护一个数组,存储已经被删掉的节点,查找时如果发现节点在该数组中,则舍弃。
这里我又小小偷懒了下,没做删除后重新插入的情况的处理
HNSW persistence
HSNW 索引持久化的目录结构如图所示:
hnsw_data/
├── global_header.bin # 全局参数文件(同原HNSWHeader结构)
├── deleted_nodes.bin # 被删除的节点数据
├── nodes/ # 节点数据存储目录
│ ├── 0/ # 节点0的数据
│ │ ├── header.bin # 向量数据(float32数组)
│ │ └── edges/ # 邻接表目录
│ │ ├── 3.bin # 第3层邻接表
│ │ ├── 2.bin # 第2层邻接表
│ │ └── ... # 其他存在的层级
│ ├── 1/ # 节点1的数据
│ └── ... # 其他节点
才发现这个主题渲染的行间距这么宽
我自己做了些小改动,在global_header.bin中我没有存全图最高层级,而是存了查询入口点的 id,一方面,入口点的高度就是全图最高层级,另一方面,重新载入时也需要确保索引的入口点不变。
感觉没有什么非常特别的,就照着要求实现就可以了,因为存入和加载的 HNSW 结构是完全一致的,所以测出来的 accept rate 也不变。
bonus
由于很多人反映测试集太小没法体现性能优势,而在本地做大量的 embedding 又不现实,助教在本阶段提供了样本量为100k的已经 embedding 好的数据,让我们尝试持久化这个大数据集,以供后续阶段的测试用。
这直接在hnsw_data/nodes/目录下生成了十万个文件夹,跑的时候电脑烫得感觉要炸了
Phase 5
本阶段要求我们实现并行化操作,只需要实现其中的某一项功能即可。笑死,最近 ICS 和 ADS 正好都在讲多线程和并行
助教给了我们两个参考程序,都是利用了std::future和std::async,简单跑了下看看效果,决定采用类似的方式完成这一阶段。
parallel get
我选了我觉得最简单的get()操作来实现并行化,逻辑是参考普通的get()操作的,代码如下:
struct parallel_pair {
bool flag;
std::string res;
uint64_t time;
};
std::string KVStore::parallel_get(uint64_t key) {
std::string res = s->search(key);
if (res.length()) { // 在memtable中找到, 或者是deleted,说明最近被删除过,
// 不用查sstable
if (res == DEL)
return "";
return res;
}
std::vector<std::future<parallel_pair>> parallel_futures;
for (int level = 0; level <= totalLevel; ++level) {
for (int j = 0; j < sstableIndex[level].size(); ++j) {
parallel_futures.push_back(std::async(std::launch::async, [=]() {
parallel_pair pair = {false, "", 0};
sstablehead ssh = sstableIndex[level][j];
if (key < ssh.getMinV() || key > ssh.getMaxV())
return pair;
uint32_t len;
int offset = ssh.searchOffset(key, len);
if (offset == -1)
return pair;
pair.flag = true;
pair.res = fetchString(ssh.getFilename(), offset + 32 + 10240 + 12 * ssh.getCnt(), len);
pair.time = ssh.getTime();
return pair;
}));
}
}
res = "";
std::uint64_t time = 0;
for (auto &future : parallel_futures) {
parallel_pair cur = future.get();
if (cur.flag && cur.time > time) {
time = cur.time;
res = cur.res;
}
}
return res;
}结构体中的flag用来标记在本次异步任务中是否查到了符合条件的key;time用来标记时间戳,保留最大的那个。
测试初次的实现时,出了个 bug,有的时候会有几个测试案例没有查到key,有的时候却又能全部 pass。遂问 AI,原因是我最初的循环写的是for (sstablehead it : sstableIndex[level]),这有可能会导致it的生命周期已经结束了,但是它需要做的异步任务却还没完成,lambda 表达式中捕获所有变量的引用,这时候就发生了悬垂引用。于是改成在 lambda 里面才根据位置取对应的sstablehead,问题就解决了。
最后测试结果:并行get()只用了普通get()不到一半的查询时间,终于写出个符合预期的真能用的了。
Review
这学期 ADS 的 Project 是首次革新,因此有好多让人做起来觉得挺奇怪的地方。特别是到后面几个 Phase,感觉真就是在面向助教编程,很刻意地在添加某些功能。但非要说收获吧,确实也不小。过几天就是答辩了,希望能顺利通过。
完结撒花