adjust message handing (#22)

* adjust message handing

* mcp required args
This commit is contained in:
Sarv
2025-04-09 00:02:55 +08:00
committed by GitHub
parent c12ee8bfce
commit b4378a63a3
11 changed files with 688 additions and 505 deletions

View File

@@ -30,7 +30,6 @@ type MessageDBInfo struct {
FilePath string
StartTime time.Time
EndTime time.Time
ID2Name map[int]string
}
type DataSource struct {
@@ -99,32 +98,10 @@ func (ds *DataSource) initMessageDbs(path string) error {
}
startTime = time.Unix(timestamp, 0)
// 获取 ID2Name 映射
id2Name := make(map[int]string)
rows, err := db.Query("SELECT user_name FROM Name2Id")
if err != nil {
log.Err(err).Msgf("获取数据库 %s 的 Name2Id 表失败", filePath)
db.Close()
continue
}
i := 1
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
log.Err(err).Msgf("数据库 %s 扫描 Name2Id 行失败", filePath)
continue
}
id2Name[i] = name
i++
}
rows.Close()
// 保存数据库信息
ds.messageFiles = append(ds.messageFiles, MessageDBInfo{
FilePath: filePath,
StartTime: startTime,
ID2Name: id2Name,
})
// 保存数据库连接
@@ -253,7 +230,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
// 对所有消息按时间排序
sort.Slice(totalMessages, func(i, j int) bool {
return totalMessages[i].Sequence < totalMessages[j].Sequence
return totalMessages[i].Seq < totalMessages[j].Seq
})
// 处理分页
@@ -288,10 +265,11 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT sort_seq, local_type, real_sender_id, create_time, message_content, packed_info_data, status
FROM %s
SELECT m.sort_seq, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY sort_seq ASC
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
if limit > 0 {
@@ -310,14 +288,13 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
// 处理查询结果
messages := []*model.Message{}
isChatRoom := strings.HasSuffix(talker, "@chatroom")
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.LocalType,
&msg.RealSenderID,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
@@ -327,7 +304,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
return nil, errors.ScanRowFailed(err)
}
messages = append(messages, msg.Wrap(dbInfo.ID2Name, isChatRoom))
messages = append(messages, msg.Wrap(talker))
}
return messages, nil
@@ -359,10 +336,11 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT sort_seq, local_type, real_sender_id, create_time, message_content, packed_info_data, status
FROM %s
SELECT m.sort_seq, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY sort_seq ASC
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
// 执行查询
@@ -378,14 +356,13 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo
// 处理查询结果
messages := []*model.Message{}
isChatRoom := strings.HasSuffix(talker, "@chatroom")
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.LocalType,
&msg.RealSenderID,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
@@ -395,7 +372,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo
return nil, errors.ScanRowFailed(err)
}
messages = append(messages, msg.Wrap(dbInfo.ID2Name, isChatRoom))
messages = append(messages, msg.Wrap(talker))
}
return messages, nil