message search (#60)
This commit is contained in:
@@ -2,9 +2,9 @@ package windowsv3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/sjzar/chatlog/internal/errors"
|
||||
"github.com/sjzar/chatlog/internal/model"
|
||||
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
|
||||
"github.com/sjzar/chatlog/pkg/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -221,21 +222,38 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes
|
||||
return dbs
|
||||
}
|
||||
|
||||
// GetMessages 实现 DataSource 接口的 GetMessages 方法
|
||||
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
|
||||
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
|
||||
if talker == "" {
|
||||
return nil, errors.ErrTalkerEmpty
|
||||
}
|
||||
|
||||
// 解析talker参数,支持多个talker(以英文逗号分隔)
|
||||
talkers := util.Str2List(talker, ",")
|
||||
if len(talkers) == 0 {
|
||||
return nil, errors.ErrTalkerEmpty
|
||||
}
|
||||
|
||||
// 找到时间范围内的数据库文件
|
||||
dbInfos := ds.getDBInfosForTimeRange(startTime, endTime)
|
||||
if len(dbInfos) == 0 {
|
||||
return nil, errors.TimeRangeNotFound(startTime, endTime)
|
||||
}
|
||||
|
||||
if len(dbInfos) == 1 {
|
||||
// LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理
|
||||
return ds.getMessagesSingleFile(ctx, dbInfos[0], startTime, endTime, talker, limit, offset)
|
||||
// 解析sender参数,支持多个发送者(以英文逗号分隔)
|
||||
senders := util.Str2List(sender, ",")
|
||||
|
||||
// 预编译正则表达式(如果有keyword)
|
||||
var regex *regexp.Regexp
|
||||
if keyword != "" {
|
||||
var err error
|
||||
regex, err = regexp.Compile(keyword)
|
||||
if err != nil {
|
||||
return nil, errors.QueryFailed("invalid regex pattern", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 从每个相关数据库中查询消息
|
||||
totalMessages := []*model.Message{}
|
||||
filteredMessages := []*model.Message{}
|
||||
|
||||
for _, dbInfo := range dbInfos {
|
||||
// 检查上下文是否已取消
|
||||
@@ -249,172 +267,137 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
|
||||
continue
|
||||
}
|
||||
|
||||
messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker)
|
||||
if err != nil {
|
||||
log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath)
|
||||
continue
|
||||
}
|
||||
// 对每个talker进行查询
|
||||
for _, talkerItem := range talkers {
|
||||
// 构建查询条件
|
||||
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
|
||||
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
|
||||
|
||||
totalMessages = append(totalMessages, messages...)
|
||||
// 添加talker条件
|
||||
talkerID, ok := dbInfo.TalkerMap[talkerItem]
|
||||
if ok {
|
||||
conditions = append(conditions, "TalkerId = ?")
|
||||
args = append(args, talkerID)
|
||||
} else {
|
||||
conditions = append(conditions, "StrTalker = ?")
|
||||
args = append(args, talkerItem)
|
||||
}
|
||||
|
||||
if limit+offset > 0 && len(totalMessages) >= limit+offset {
|
||||
break
|
||||
query := fmt.Sprintf(`
|
||||
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
|
||||
Type, SubType, StrContent, CompressContent, BytesExtra
|
||||
FROM MSG
|
||||
WHERE %s
|
||||
ORDER BY Sequence ASC
|
||||
`, strings.Join(conditions, " AND "))
|
||||
|
||||
// 执行查询
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
// 如果表不存在,跳过此talker
|
||||
if strings.Contains(err.Error(), "no such table") {
|
||||
continue
|
||||
}
|
||||
log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath)
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理查询结果,在读取时进行过滤
|
||||
for rows.Next() {
|
||||
var msg model.MessageV3
|
||||
var compressContent []byte
|
||||
var bytesExtra []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&msg.MsgSvrID,
|
||||
&msg.Sequence,
|
||||
&msg.CreateTime,
|
||||
&msg.StrTalker,
|
||||
&msg.IsSender,
|
||||
&msg.Type,
|
||||
&msg.SubType,
|
||||
&msg.StrContent,
|
||||
&compressContent,
|
||||
&bytesExtra,
|
||||
)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, errors.ScanRowFailed(err)
|
||||
}
|
||||
msg.CompressContent = compressContent
|
||||
msg.BytesExtra = bytesExtra
|
||||
|
||||
// 将消息转换为标准格式
|
||||
message := msg.Wrap()
|
||||
|
||||
// 应用sender过滤
|
||||
if len(senders) > 0 {
|
||||
senderMatch := false
|
||||
for _, s := range senders {
|
||||
if message.Sender == s {
|
||||
senderMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !senderMatch {
|
||||
continue // 不匹配sender,跳过此消息
|
||||
}
|
||||
}
|
||||
|
||||
// 应用keyword过滤
|
||||
if regex != nil {
|
||||
plainText := message.PlainTextContent()
|
||||
if !regex.MatchString(plainText) {
|
||||
continue // 不匹配keyword,跳过此消息
|
||||
}
|
||||
}
|
||||
|
||||
// 通过所有过滤条件,保留此消息
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
|
||||
// 检查是否已经满足分页处理数量
|
||||
if limit > 0 && len(filteredMessages) >= offset+limit {
|
||||
// 已经获取了足够的消息,可以提前返回
|
||||
rows.Close()
|
||||
|
||||
// 对所有消息按时间排序
|
||||
sort.Slice(filteredMessages, func(i, j int) bool {
|
||||
return filteredMessages[i].Seq < filteredMessages[j].Seq
|
||||
})
|
||||
|
||||
// 处理分页
|
||||
if offset >= len(filteredMessages) {
|
||||
return []*model.Message{}, nil
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(filteredMessages) {
|
||||
end = len(filteredMessages)
|
||||
}
|
||||
return filteredMessages[offset:end], nil
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// 对所有消息按时间排序
|
||||
sort.Slice(totalMessages, func(i, j int) bool {
|
||||
return totalMessages[i].Seq < totalMessages[j].Seq
|
||||
sort.Slice(filteredMessages, func(i, j int) bool {
|
||||
return filteredMessages[i].Seq < filteredMessages[j].Seq
|
||||
})
|
||||
|
||||
// 处理分页
|
||||
if limit > 0 {
|
||||
if offset >= len(totalMessages) {
|
||||
if offset >= len(filteredMessages) {
|
||||
return []*model.Message{}, nil
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(totalMessages) {
|
||||
end = len(totalMessages)
|
||||
if end > len(filteredMessages) {
|
||||
end = len(filteredMessages)
|
||||
}
|
||||
return totalMessages[offset:end], nil
|
||||
return filteredMessages[offset:end], nil
|
||||
}
|
||||
|
||||
return totalMessages, nil
|
||||
}
|
||||
|
||||
// getMessagesSingleFile 从单个数据库文件获取消息
|
||||
func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
|
||||
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
|
||||
if err != nil {
|
||||
return nil, errors.DBConnectFailed(dbInfo.FilePath, nil)
|
||||
}
|
||||
|
||||
// 构建查询条件
|
||||
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
|
||||
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
|
||||
if len(talker) > 0 {
|
||||
// TalkerId 有索引,优先使用
|
||||
talkerID, ok := dbInfo.TalkerMap[talker]
|
||||
if ok {
|
||||
conditions = append(conditions, "TalkerId = ?")
|
||||
args = append(args, talkerID)
|
||||
} else {
|
||||
conditions = append(conditions, "StrTalker = ?")
|
||||
args = append(args, talker)
|
||||
}
|
||||
}
|
||||
query := fmt.Sprintf(`
|
||||
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
|
||||
Type, SubType, StrContent, CompressContent, BytesExtra
|
||||
FROM MSG
|
||||
WHERE %s
|
||||
ORDER BY Sequence ASC
|
||||
`, strings.Join(conditions, " AND "))
|
||||
|
||||
if limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||
|
||||
if offset > 0 {
|
||||
query += fmt.Sprintf(" OFFSET %d", offset)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.QueryFailed(query, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 处理查询结果
|
||||
totalMessages := []*model.Message{}
|
||||
for rows.Next() {
|
||||
var msg model.MessageV3
|
||||
var compressContent []byte
|
||||
var bytesExtra []byte
|
||||
err := rows.Scan(
|
||||
&msg.MsgSvrID,
|
||||
&msg.Sequence,
|
||||
&msg.CreateTime,
|
||||
&msg.StrTalker,
|
||||
&msg.IsSender,
|
||||
&msg.Type,
|
||||
&msg.SubType,
|
||||
&msg.StrContent,
|
||||
&compressContent,
|
||||
&bytesExtra,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.ScanRowFailed(err)
|
||||
}
|
||||
msg.CompressContent = compressContent
|
||||
msg.BytesExtra = bytesExtra
|
||||
totalMessages = append(totalMessages, msg.Wrap())
|
||||
}
|
||||
return totalMessages, nil
|
||||
}
|
||||
|
||||
// getMessagesFromDB 从数据库获取消息
|
||||
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
|
||||
// 构建查询条件
|
||||
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
|
||||
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
|
||||
|
||||
if len(talker) > 0 {
|
||||
talkerID, ok := dbInfo.TalkerMap[talker]
|
||||
if ok {
|
||||
conditions = append(conditions, "TalkerId = ?")
|
||||
args = append(args, talkerID)
|
||||
} else {
|
||||
conditions = append(conditions, "StrTalker = ?")
|
||||
args = append(args, talker)
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
|
||||
Type, SubType, StrContent, CompressContent, BytesExtra
|
||||
FROM MSG
|
||||
WHERE %s
|
||||
ORDER BY Sequence ASC
|
||||
`, strings.Join(conditions, " AND "))
|
||||
|
||||
// 执行查询
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.QueryFailed(query, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 处理查询结果
|
||||
messages := []*model.Message{}
|
||||
for rows.Next() {
|
||||
var msg model.MessageV3
|
||||
var compressContent []byte
|
||||
var bytesExtra []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&msg.MsgSvrID,
|
||||
&msg.Sequence,
|
||||
&msg.CreateTime,
|
||||
&msg.StrTalker,
|
||||
&msg.IsSender,
|
||||
&msg.Type,
|
||||
&msg.SubType,
|
||||
&msg.StrContent,
|
||||
&compressContent,
|
||||
&bytesExtra,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.ScanRowFailed(err)
|
||||
}
|
||||
msg.CompressContent = compressContent
|
||||
msg.BytesExtra = bytesExtra
|
||||
|
||||
messages = append(messages, msg.Wrap())
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
return filteredMessages, nil
|
||||
}
|
||||
|
||||
// GetContacts 实现获取联系人信息的方法
|
||||
|
||||
Reference in New Issue
Block a user