目录
一、代码解析
1.1 searchTest.ts
1.2 controller.ts
本文承接上一篇文章《FastGPT 知识库搜索测试功能解析》,对具体代码进行解析。
一、代码解析
FastGPT 知识库的搜索测试功能主要涉及两个文件,分别是 searchTest.ts 和 controller.ts 文件,下面分别进行介绍。
1.1 searchTest.ts
文件路径是 projects/app/src/pages/api/core/dataset/searchTest.ts,搜索测试功能的主文件,代码如下所示。
/**
 * API handler behind the dataset "search test" feature.
 *
 * Flow: validate the request body, authorize read access to the dataset,
 * check the team's AI points, optionally extend the query with an LLM,
 * run `searchDatasetData`, record usage, and return the hits together with
 * the elapsed time.
 *
 * @param req - Incoming request; `req.body` is read as `SearchTestProps`.
 * @returns `list` (search hits), `duration` (seconds, 3 decimals, suffixed
 *          with "s"), `queryExtensionModel`, plus the remaining fields
 *          returned by `searchDatasetData`.
 *          Rejects with `CommonErrEnum.missingParams` when `datasetId` or
 *          `text` is missing.
 */
async function handler(req: NextApiRequest) {
  console.log("function handler(req: NextApiRequest)")

  const body = req.body as SearchTestProps;
  const {
    datasetId, // dataset (knowledge base) id
    text, // text typed into the search-test box
    limit = 1500, // upper bound on quoted tokens
    similarity, // minimum relevance; searchDatasetData defaults it to 0
    searchMode, // search mode
    usingReRank, // whether to re-rank recalled text (needs a rerank model)
    datasetSearchUsingExtensionQuery = false, // whether query extension is enabled
    datasetSearchExtensionModel, // model used for query extension
    datasetSearchExtensionBg = '' // conversation background used for query extension
  } = body;

  // Both the dataset id and the search text are mandatory.
  if (!(datasetId && text)) {
    return Promise.reject(CommonErrEnum.missingParams);
  }

  // Start of the timing window reported as `duration`.
  const startedAt = Date.now();

  // Authorize the caller (token or API key) for read access to the dataset.
  const authResult = await authDataset({
    req,
    authToken: true,
    authApiKey: true,
    datasetId,
    per: ReadPermissionVal
  });
  const { dataset, teamId, tmbId, apikey } = authResult;

  // auth balance
  await checkTeamAIPoints(teamId);

  // Resolve the extension model only when extension is enabled and a model was named.
  const extensionEnabled = datasetSearchUsingExtensionQuery && datasetSearchExtensionModel;
  const extensionModel = extensionEnabled ? getLLMModel(datasetSearchExtensionModel) : undefined;

  // Let the LLM extend the query (the helper receives `undefined` when extension is off).
  const extension = await datasetSearchQueryExtension({
    query: text,
    extensionModel,
    extensionBg: datasetSearchExtensionBg
  });
  const { concatQueries, rewriteQuery, aiExtensionResult } = extension;

  console.log("[test]: pre searchDatasetData");

  // Re-ranking additionally requires the team-level permission check to pass.
  const reRankAllowed = usingReRank && (await checkTeamReRankPermission(teamId));

  // Run the actual recall; everything except hits and token count is passed through to the response.
  const {
    searchRes: hits,
    tokens: usedTokens,
    ...searchMeta
  } = await searchDatasetData({
    teamId,
    reRankQuery: rewriteQuery,
    queries: concatQueries,
    model: dataset.vectorModel,
    limit: Math.min(limit, 20000),
    similarity,
    datasetIds: [datasetId],
    searchMode,
    usingReRank: reRankAllowed
  });

  // Extension usage is only billed when an extension actually ran with a resolved model.
  const extensionUsage =
    aiExtensionResult && extensionModel
      ? {
          extensionModel: extensionModel.name,
          extensionTokens: aiExtensionResult.tokens
        }
      : {};

  // push bill
  const { totalPoints } = pushGenerateVectorUsage({
    teamId,
    tmbId,
    tokens: usedTokens,
    model: dataset.vectorModel,
    source: apikey ? UsageSourceEnum.api : UsageSourceEnum.fastgpt,
    ...extensionUsage
  });

  // When the call came in through an API key, record the points against that key (not awaited).
  if (apikey) {
    updateApiKeyUsage({ apikey, totalPoints });
  }

  const elapsedSeconds = ((Date.now() - startedAt) / 1000).toFixed(3);
  return {
    list: hits, // search results
    duration: `${elapsedSeconds}s`, // elapsed time
    queryExtensionModel: aiExtensionResult?.model,
    ...searchMeta
  };
}
// Default export of this API route: the handler wrapped by NextAPI.
export default NextAPI(handler);
函数 handler 主要是打辅助,主力在 searchDatasetData 函数中。
函数 handler 接收的参数大多来自知识库搜索配置中的设置项,如下所示。
1.2 controller.ts
主要处理逻辑在 searchDatasetData 函数中:它先调用 getVectorsByText 获取测试文本的向量表示,再在 pgvector 中查询相似度高的向量,最后通过 mongodb 查询这些向量对应的原文。
/** Arguments accepted by `searchDatasetData`. */
type SearchDatasetDataProps = {
  /** Team id; used to scope the vector-store and MongoDB lookups. */
  teamId: string;
  /** Ids of the datasets to search. */
  datasetIds: string[];
  /** Queries to recall with; each one is run through every enabled recall channel. */
  queries: string[];
  /** Query text handed to the re-ranker. */
  reRankQuery: string;
  /** Name of the vector model used to embed the queries. */
  model: string;
  /** max Token limit */
  limit: number;
  /** min distance (results scoring below it are filtered out; defaults to 0 in searchDatasetData) */
  similarity?: number;
  /** Search mode; searchDatasetData falls back to embedding mode when omitted or unknown. */
  searchMode?: `${DatasetSearchModeEnum}`;
  /** Whether to re-rank the recalled results. */
  usingReRank?: boolean;
};
/**
 * Recall dataset chunks for a set of queries, merge and rank them, and trim
 * the result to a token budget.
 *
 * Main steps (the helpers are inner closures defined below):
 *  1. `countRecallLimit`  - pick per-channel recall sizes from the search mode.
 *  2. `multiQueryRecall`  - run embedding recall and full-text recall for every
 *                           query and merge each channel with
 *                           `datasetSearchResultConcat`.
 *  3. `reRankSearchResult` - optionally re-rank the de-duplicated candidates.
 *  4. Merge all channels, drop chunks with identical q/a text, apply the
 *     similarity filter, then `filterResultsByMaxTokens`.
 *
 * @param props - See `SearchDatasetDataProps`. `similarity` defaults to 0,
 *                `searchMode` to embedding, `usingReRank` to false; a `limit`
 *                below 50 is replaced by 1500.
 * @returns `searchRes` (final hits), `tokens` (sum of the token counts reported
 *          by `getVectorsByText`), and the effective `searchMode`, `limit`,
 *          `similarity`, `usingReRank` and `usingSimilarityFilter` values.
 */
export async function searchDatasetData(props: SearchDatasetDataProps) {
  console.log("function searchDatasetData");
  let {
    teamId,
    reRankQuery,
    queries,
    model,
    similarity = 0,
    limit: maxTokens,
    searchMode = DatasetSearchModeEnum.embedding,
    usingReRank = false,
    datasetIds = []
  } = props;

  /* init params */
  // Fall back to embedding search when the requested mode is not present in DatasetSearchModeMap.
  searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
  // Re-ranking is only kept on when at least one rerank model is configured globally.
  usingReRank = usingReRank && global.reRankModels.length > 0;
  // Compatible with topk limit (presumably old callers passed a small top-k here instead of a token budget).
  if (maxTokens < 50) {
    maxTokens = 1500;
  }
  // Scratch set reused by the de-duplication passes in the main flow.
  let set = new Set<string>();
  let usingSimilarityFilter = false;

  /* function */
  // 1. countRecallLimit: per-channel recall sizes for the three search modes.
  const countRecallLimit = () => {
    if (searchMode === DatasetSearchModeEnum.embedding) {
      // semantic (embedding) search only
      return {
        embeddingLimit: 100,
        fullTextLimit: 0
      };
    }
    if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
      // full-text search only
      return {
        embeddingLimit: 0,
        fullTextLimit: 100
      };
    }
    // mixed search: both channels
    return {
      embeddingLimit: 80,
      fullTextLimit: 60
    };
  };

  // 2. embeddingRecall: embed the query, look up similar vectors, then load the matching text from MongoDB.
  const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
    // Turn the query text into vectors; `tokens` is the usage reported by the helper.
    const { vectors, tokens } = await getVectorsByText({
      model: getVectorModel(model), // resolve the model configuration from its name
      input: query,
      type: 'query'
    });

    // Find similar vectors in the vector store (pgvector, per the original note).
    const { results } = await recallFromVectorStore({
      teamId,
      datasetIds,
      vector: vectors[0],
      limit
    });

    // get q and a: load the data rows whose indexes reference the recalled vector ids.
    const dataList = (await MongoDatasetData.find(
      {
        teamId,
        datasetId: { $in: datasetIds },
        collectionId: { $in: Array.from(new Set(results.map((item) => item.collectionId))) },
        'indexes.dataId': { $in: results.map((item) => item.id?.trim()) }
      },
      'datasetId collectionId q a chunkIndex indexes'
    )
      .populate('collectionId', 'name fileId rawLink externalFileId externalFileUrl')
      .lean()) as DatasetDataWithCollectionType[];

    // add score to data(It's already sorted. The first one is the one with the most points)
    // NOTE(review): ids are trimmed for the Mongo query above but compared untrimmed here;
    // if ids can carry surrounding whitespace the score silently falls back to 0 - confirm.
    const concatResults = dataList.map((data) => {
      const dataIdList = data.indexes.map((item) => item.dataId);
      const maxScoreResult = results.find((item) => {
        return dataIdList.includes(item.id);
      });
      return {
        ...data,
        score: maxScoreResult?.score || 0
      };
    });

    // Highest score first.
    concatResults.sort((a, b) => b.score - a.score);

    const formatResult = concatResults.map((data, index) => {
      if (!data.collectionId) {
        console.log('Collection is not found', data);
      }
      const result: SearchDataResponseItemType = {
        id: String(data._id),
        q: data.q,
        a: data.a,
        chunkIndex: data.chunkIndex,
        datasetId: String(data.datasetId),
        collectionId: String(data.collectionId?._id),
        ...getCollectionSourceData(data.collectionId),
        score: [{ type: SearchScoreTypeEnum.embedding, value: data.score, index }]
      };
      return result;
    });

    return {
      embeddingRecallResults: formatResult,
      tokens
    };
  };

  // 3. fullTextRecall: MongoDB $text search per dataset, merged and ordered by text score.
  const fullTextRecall = async ({
    query,
    limit
  }: {
    query: string;
    limit: number;
  }): Promise<{
    fullTextRecallResults: SearchDataResponseItemType[];
    tokenLen: number;
  }> => {
    // A limit of 0 means the channel is disabled for the current search mode.
    if (limit === 0) {
      return {
        fullTextRecallResults: [],
        tokenLen: 0
      };
    }

    // One query per dataset, each capped at `limit`, flattened into a single list.
    let searchResults = (
      await Promise.all(
        datasetIds.map((id) =>
          MongoDatasetData.find(
            {
              teamId,
              datasetId: id,
              $text: { $search: jiebaSplit({ text: query }) }
            },
            {
              score: { $meta: 'textScore' },
              _id: 1,
              datasetId: 1,
              collectionId: 1,
              q: 1,
              a: 1,
              chunkIndex: 1
            }
          )
            .sort({ score: { $meta: 'textScore' } })
            .limit(limit)
            .lean()
        )
      )
    ).flat() as (DatasetDataSchemaType & { score: number })[];

    // resort across datasets, then keep the best `limit` rows.
    // `slice` returns a new array, so its result has to be assigned back.
    searchResults.sort((a, b) => b.score - a.score);
    searchResults = searchResults.slice(0, limit);

    const collections = await MongoDatasetCollection.find(
      {
        _id: { $in: searchResults.map((item) => item.collectionId) }
      },
      '_id name fileId rawLink'
    );

    return {
      fullTextRecallResults: searchResults.map((item, index) => {
        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
        return {
          id: String(item._id),
          datasetId: String(item.datasetId),
          collectionId: String(item.collectionId),
          ...getCollectionSourceData(collection),
          q: item.q,
          a: item.a,
          chunkIndex: item.chunkIndex,
          // NOTE(review): `indexes` is not part of the projection above, so this is presumably undefined - confirm.
          indexes: item.indexes,
          score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
        };
      }),
      tokenLen: 0
    };
  };

  // 4. reRankSearchResult: score the candidates with the rerank model; on failure or empty output, disable rerank.
  const reRankSearchResult = async ({
    data,
    query
  }: {
    data: SearchDataResponseItemType[];
    query: string;
  }): Promise<SearchDataResponseItemType[]> => {
    try {
      const results = await reRankRecall({
        query,
        documents: data.map((item) => ({
          id: item.id,
          text: `${item.q}\n${item.a}`
        }))
      });

      if (results.length === 0) {
        usingReRank = false;
        return [];
      }

      // add new score to data; rerank entries whose id is not among the candidates are dropped.
      const mergeResult = results
        .map((item, index) => {
          const target = data.find((dataItem) => dataItem.id === item.id);
          if (!target) return null;
          const score = item.score || 0;
          return {
            ...target,
            score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
          };
        })
        .filter(Boolean) as SearchDataResponseItemType[];

      return mergeResult;
    } catch (error) {
      // Best effort: a failing re-ranker must not fail the search, so fall back to the recall results only.
      usingReRank = false;
      return [];
    }
  };

  // 5. filterResultsByMaxTokens: keep items in order until the token budget is reached.
  const filterResultsByMaxTokens = async (
    list: SearchDataResponseItemType[],
    maxTokens: number
  ) => {
    const results: SearchDataResponseItemType[] = [];
    let totalTokens = 0;

    for (const item of list) {
      totalTokens += await countPromptTokens(item.q + item.a);

      // More than 500 tokens over budget: drop this item and stop.
      if (totalTokens > maxTokens + 500) {
        break;
      }
      results.push(item);
      // Budget reached (within the 500-token tolerance): keep this item and stop.
      if (totalTokens > maxTokens) {
        break;
      }
    }

    // Never return nothing when there was at least one candidate.
    return results.length === 0 ? list.slice(0, 1) : results;
  };

  // 6. multiQueryRecall: run both recall channels for every query, then merge each channel across queries.
  const multiQueryRecall = async ({
    embeddingLimit,
    fullTextLimit
  }: {
    embeddingLimit: number;
    fullTextLimit: number;
  }) => {
    // multi query recall
    const embeddingRecallResList: SearchDataResponseItemType[][] = [];
    const fullTextRecallResList: SearchDataResponseItemType[][] = [];
    let totalTokens = 0;

    await Promise.all(
      queries.map(async (query) => {
        // Both channels run concurrently for each query.
        const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
          embeddingRecall({
            query,
            limit: embeddingLimit
          }),
          fullTextRecall({
            query,
            limit: fullTextLimit
          })
        ]);
        totalTokens += tokens;

        embeddingRecallResList.push(embeddingRecallResults);
        fullTextRecallResList.push(fullTextRecallResults);
      })
    );

    // rrf concat: merge the per-query lists of each channel and cap them at the channel limit.
    const rrfEmbRecall = datasetSearchResultConcat(
      embeddingRecallResList.map((list) => ({ k: 60, list }))
    ).slice(0, embeddingLimit);
    const rrfFTRecall = datasetSearchResultConcat(
      fullTextRecallResList.map((list) => ({ k: 60, list }))
    ).slice(0, fullTextLimit);

    return {
      tokens: totalTokens,
      embeddingRecallResults: rrfEmbRecall,
      fullTextRecallResults: rrfFTRecall
    };
  };

  // Everything above only defines helpers; the actual work starts here.
  /* main step */
  // count limit
  const { embeddingLimit, fullTextLimit } = countRecallLimit();

  // recall
  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
    embeddingLimit,
    fullTextLimit
  });

  // ReRank results
  const reRankResults = await (async () => {
    if (!usingReRank) return [];

    // Union of both channels by id (embedding results win on id collisions).
    set = new Set<string>(embeddingRecallResults.map((item) => item.id));
    const concatRecallResults = embeddingRecallResults.concat(
      fullTextRecallResults.filter((item) => !set.has(item.id))
    );

    // remove same q and a data
    set = new Set<string>();
    const filterSameDataResults = concatRecallResults.filter((item) => {
      // Strip everything except letters and digits (punctuation, whitespace, ...) so only the text is compared.
      const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
      if (set.has(str)) return false;
      set.add(str);
      return true;
    });

    return reRankSearchResult({
      query: reRankQuery,
      data: filterSameDataResults
    });
  })();

  // embedding recall and fullText recall rrf concat (plus the rerank list, which is empty when rerank is off)
  const rrfConcatResults = datasetSearchResultConcat([
    { k: 60, list: embeddingRecallResults },
    { k: 60, list: fullTextRecallResults },
    { k: 58, list: reRankResults }
  ]);

  // remove same q and a data
  set = new Set<string>();
  const filterSameDataResults = rrfConcatResults.filter((item) => {
    // Strip everything except letters and digits (punctuation, whitespace, ...) so only the text is compared.
    const str = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
    if (set.has(str)) return false;
    set.add(str);
    return true;
  });

  // score filter: the similarity threshold applies to the rerank score when rerank ran,
  // otherwise to the embedding score in pure embedding mode; other modes are not filtered.
  const scoreFilter = (() => {
    if (usingReRank) {
      usingSimilarityFilter = true;

      return filterSameDataResults.filter((item) => {
        const reRankScore = item.score.find((item) => item.type === SearchScoreTypeEnum.reRank);
        // Items without a rerank score are kept.
        if (reRankScore && reRankScore.value < similarity) return false;
        return true;
      });
    }
    if (searchMode === DatasetSearchModeEnum.embedding) {
      usingSimilarityFilter = true;
      return filterSameDataResults.filter((item) => {
        const embeddingScore = item.score.find(
          (item) => item.type === SearchScoreTypeEnum.embedding
        );
        // Items without an embedding score are kept.
        if (embeddingScore && embeddingScore.value < similarity) return false;
        return true;
      });
    }
    return filterSameDataResults;
  })();

  return {
    searchRes: await filterResultsByMaxTokens(scoreFilter, maxTokens),
    tokens,
    searchMode,
    limit: maxTokens,
    similarity,
    usingReRank,
    usingSimilarityFilter
  };
}
multiQueryRecall : 首先将 query 转换为 vector,然后在 pgvector 中检索相似向量,最后在 mongodb 中查找 vector 对应的文本,处理后返回。这些逻辑主要在 embeddingRecall 函数中实现。
getVectorsByText : 负责将搜索的问题转换为向量表示;
recallFromVectorStore : 在 pg vector 中查找相似向量;
MongoDatasetData.find :根据 recallFromVectorStore 查询出的相似向量,在 mongodb 中找出对应的原文本。
其他内容后面再详细展开介绍。
参考链接:
[1] FastGPT源码深度剖析:混合检索及语料召回逻辑 - 技术栈