322 lines
7.8 KiB
Go
322 lines
7.8 KiB
Go
package wechatdb
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sjzar/chatlog/pkg/model"
|
|
"github.com/sjzar/chatlog/pkg/util"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Regular-expression file-name patterns handed to util.FindFilesWithPatterns
// when locating message database files. NewMessage uses MessageFileV3.
const (
	// MessageFileV3 matches MSG.db and MSG0.db through MSG99.db.
	MessageFileV3 = `^MSG([0-9]?[0-9])?\.db$`

	// MessageFileV4 matches messages_<digits>.db (e.g. messages_0.db,
	// messages_12.db); the repeated group accepts any run of digits.
	MessageFileV4 = `^messages_([0-9]?[0-9])+\.db$`
)
|
|
|
|
type (
	// Message is a set of opened message database files plus the per-file
	// metadata used to route time-range queries to the right file(s).
	Message struct {
		// version is the value passed to NewMessage.
		version int
		// files holds one entry per successfully opened database; NewMessage
		// sorts it by StartTime.
		files []MsgDBInfo
		// dbs maps MsgDBInfo.FilePath to its open connection.
		dbs map[string]*sql.DB
	}

	// MsgDBInfo describes a single message database file.
	MsgDBInfo struct {
		// FilePath is the path the database was opened from.
		FilePath string
		// StartTime is read from the DBInfo row whose description contains
		// "Start Time"; it is the zero time when no such row exists.
		StartTime time.Time
		// EndTime is the next file's StartTime, or the construction time for
		// the newest file.
		EndTime time.Time
		// TalkerMap maps a talker user name (Name2ID.UsrName) to the numeric
		// id compared against MSG.TalkerId.
		TalkerMap map[string]int
	}
)
|
|
|
|
func NewMessage(path string, version int) (*Message, error) {
|
|
m := &Message{
|
|
version: version,
|
|
files: make([]MsgDBInfo, 0),
|
|
dbs: make(map[string]*sql.DB),
|
|
}
|
|
|
|
// 查找所有 MSG[0-13].db 文件
|
|
files, err := util.FindFilesWithPatterns(path, MessageFileV3, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查找数据库文件失败: %v", err)
|
|
}
|
|
|
|
if len(files) == 0 {
|
|
return nil, fmt.Errorf("未找到任何数据库文件: %s", path)
|
|
}
|
|
|
|
// 处理每个数据库文件
|
|
for _, filePath := range files {
|
|
// 连接数据库
|
|
db, err := sql.Open("sqlite3", filePath)
|
|
if err != nil {
|
|
log.Printf("警告: 连接数据库 %s 失败: %v", filePath, err)
|
|
continue
|
|
}
|
|
|
|
// 获取 DBInfo 表中的开始时间
|
|
// 首先检查表结构
|
|
var startTime time.Time
|
|
|
|
// 尝试从 DBInfo 表中查找 Start Time 对应的记录
|
|
rows, err := db.Query("SELECT tableIndex, tableVersion, tableDesc FROM DBInfo")
|
|
if err != nil {
|
|
log.Printf("警告: 查询数据库 %s 的 DBInfo 表失败: %v", filePath, err)
|
|
db.Close()
|
|
continue
|
|
}
|
|
|
|
for rows.Next() {
|
|
var tableIndex int
|
|
var tableVersion int64
|
|
var tableDesc string
|
|
|
|
if err := rows.Scan(&tableIndex, &tableVersion, &tableDesc); err != nil {
|
|
log.Printf("警告: 扫描 DBInfo 行失败: %v", err)
|
|
continue
|
|
}
|
|
|
|
// 查找描述为 "Start Time" 的记录
|
|
if strings.Contains(tableDesc, "Start Time") {
|
|
startTime = time.Unix(tableVersion/1000, (tableVersion%1000)*1000000)
|
|
break
|
|
}
|
|
}
|
|
rows.Close()
|
|
|
|
// 组织 TalkerMap
|
|
talkerMap := make(map[string]int)
|
|
rows, err = db.Query("SELECT UsrName FROM Name2ID")
|
|
if err != nil {
|
|
log.Printf("警告: 查询数据库 %s 的 Name2ID 表失败: %v", filePath, err)
|
|
db.Close()
|
|
continue
|
|
}
|
|
|
|
i := 1
|
|
for rows.Next() {
|
|
var userName string
|
|
if err := rows.Scan(&userName); err != nil {
|
|
log.Printf("警告: 扫描 Name2ID 行失败: %v", err)
|
|
continue
|
|
}
|
|
talkerMap[userName] = i
|
|
i++
|
|
}
|
|
|
|
// 保存数据库信息
|
|
m.files = append(m.files, MsgDBInfo{
|
|
FilePath: filePath,
|
|
StartTime: startTime,
|
|
TalkerMap: talkerMap,
|
|
})
|
|
|
|
// 保存数据库连接
|
|
m.dbs[filePath] = db
|
|
}
|
|
|
|
// 按照 StartTime 排序数据库文件
|
|
sort.Slice(m.files, func(i, j int) bool {
|
|
return m.files[i].StartTime.Before(m.files[j].StartTime)
|
|
})
|
|
|
|
for i := range m.files {
|
|
if i == len(m.files)-1 {
|
|
m.files[i].EndTime = time.Now()
|
|
} else {
|
|
m.files[i].EndTime = m.files[i+1].StartTime
|
|
}
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
// GetMessages 根据时间段和 talker 查询聊天记录
|
|
func (m *Message) GetMessages(startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
|
|
// 找到时间范围内的数据库文件
|
|
dbInfos := m.getDBInfosForTimeRange(startTime, endTime)
|
|
if len(dbInfos) == 0 {
|
|
return nil, fmt.Errorf("未找到时间范围 %v 到 %v 内的数据库文件", startTime, endTime)
|
|
}
|
|
|
|
if len(dbInfos) == 1 {
|
|
// LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理
|
|
return m.getMessagesSingleFile(dbInfos[0], startTime, endTime, talker, limit, offset)
|
|
}
|
|
|
|
// 从每个相关数据库中查询消息
|
|
totalMessages := []*model.Message{}
|
|
|
|
for _, dbInfo := range dbInfos {
|
|
db, ok := m.dbs[dbInfo.FilePath]
|
|
if !ok {
|
|
log.Printf("警告: 数据库 %s 未打开", dbInfo.FilePath)
|
|
continue
|
|
}
|
|
|
|
// 构建查询条件
|
|
// 使用 Sequence 查询,有索引
|
|
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
|
|
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
|
|
|
|
if len(talker) > 0 {
|
|
talkerID, ok := dbInfo.TalkerMap[talker]
|
|
if ok {
|
|
conditions = append(conditions, "TalkerId = ?")
|
|
args = append(args, talkerID)
|
|
} else {
|
|
conditions = append(conditions, "StrTalker = ?")
|
|
args = append(args, talker)
|
|
}
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT Sequence, CreateTime, TalkerId, StrTalker, IsSender,
|
|
Type, SubType, StrContent, CompressContent, BytesExtra
|
|
FROM MSG
|
|
WHERE %s
|
|
ORDER BY Sequence ASC
|
|
`, strings.Join(conditions, " AND "))
|
|
|
|
// 执行查询
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
log.Printf("警告: 查询数据库 %s 失败: %v", dbInfo.FilePath, err)
|
|
continue
|
|
}
|
|
|
|
// 处理查询结果
|
|
for rows.Next() {
|
|
var msg model.MessageV3
|
|
var compressContent []byte
|
|
var bytesExtra []byte
|
|
|
|
err := rows.Scan(
|
|
&msg.Sequence,
|
|
&msg.CreateTime,
|
|
&msg.TalkerID,
|
|
&msg.StrTalker,
|
|
&msg.IsSender,
|
|
&msg.Type,
|
|
&msg.SubType,
|
|
&msg.StrContent,
|
|
&compressContent,
|
|
&bytesExtra,
|
|
)
|
|
if err != nil {
|
|
log.Printf("警告: 扫描消息行失败: %v", err)
|
|
continue
|
|
}
|
|
msg.CompressContent = compressContent
|
|
msg.BytesExtra = bytesExtra
|
|
|
|
totalMessages = append(totalMessages, msg.Wrap())
|
|
}
|
|
rows.Close()
|
|
|
|
if limit+offset > 0 && len(totalMessages) >= limit+offset {
|
|
break
|
|
}
|
|
}
|
|
|
|
// 对所有消息按时间排序
|
|
sort.Slice(totalMessages, func(i, j int) bool {
|
|
return totalMessages[i].Sequence < totalMessages[j].Sequence
|
|
})
|
|
|
|
// FIXME limit 和 offset 逻辑,在多文件边界条件下不好处理,直接查询全量数据后在进程里处理
|
|
if limit > 0 {
|
|
if offset >= len(totalMessages) {
|
|
return []*model.Message{}, nil
|
|
}
|
|
end := offset + limit
|
|
if end > len(totalMessages) || limit == 0 {
|
|
end = len(totalMessages)
|
|
}
|
|
return totalMessages[offset:end], nil
|
|
}
|
|
|
|
return totalMessages, nil
|
|
|
|
}
|
|
|
|
func (m *Message) getMessagesSingleFile(dbInfo MsgDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
|
|
// 构建查询条件
|
|
// 使用 Sequence 查询,有索引
|
|
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
|
|
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
|
|
if len(talker) > 0 {
|
|
// TalkerId 有索引,优先使用
|
|
talkerID, ok := dbInfo.TalkerMap[talker]
|
|
if ok {
|
|
conditions = append(conditions, "TalkerId = ?")
|
|
args = append(args, talkerID)
|
|
} else {
|
|
conditions = append(conditions, "StrTalker = ?")
|
|
args = append(args, talker)
|
|
}
|
|
}
|
|
query := fmt.Sprintf(`
|
|
SELECT Sequence, CreateTime, TalkerId, StrTalker, IsSender,
|
|
Type, SubType, StrContent, CompressContent, BytesExtra
|
|
FROM MSG
|
|
WHERE %s
|
|
ORDER BY Sequence ASC
|
|
`, strings.Join(conditions, " AND "))
|
|
|
|
if limit > 0 {
|
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
|
|
|
if offset > 0 {
|
|
query += fmt.Sprintf(" OFFSET %d", offset)
|
|
}
|
|
}
|
|
|
|
// 执行查询
|
|
rows, err := m.dbs[dbInfo.FilePath].Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询数据库 %s 失败: %v", dbInfo.FilePath, err)
|
|
}
|
|
defer rows.Close()
|
|
// 处理查询结果
|
|
totalMessages := []*model.Message{}
|
|
for rows.Next() {
|
|
var msg model.MessageV3
|
|
var compressContent []byte
|
|
var bytesExtra []byte
|
|
err := rows.Scan(
|
|
&msg.Sequence,
|
|
&msg.CreateTime,
|
|
&msg.TalkerID,
|
|
&msg.StrTalker,
|
|
&msg.IsSender,
|
|
&msg.Type,
|
|
&msg.SubType,
|
|
&msg.StrContent,
|
|
&compressContent,
|
|
&bytesExtra,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("扫描消息行失败: %v", err)
|
|
}
|
|
msg.CompressContent = compressContent
|
|
msg.BytesExtra = bytesExtra
|
|
totalMessages = append(totalMessages, msg.Wrap())
|
|
}
|
|
return totalMessages, nil
|
|
}
|
|
|
|
func (m *Message) getDBInfosForTimeRange(startTime, endTime time.Time) []MsgDBInfo {
|
|
var dbs []MsgDBInfo
|
|
for _, info := range m.files {
|
|
if info.StartTime.Before(endTime) && info.EndTime.After(startTime) {
|
|
dbs = append(dbs, info)
|
|
}
|
|
}
|
|
return dbs
|
|
}
|