Files
chatlog/internal/wechatdb/datasource/v4/datasource.go
2025-04-19 03:06:05 +08:00

723 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package v4
import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"fmt"
"regexp"
"sort"
"strings"
"time"
"github.com/fsnotify/fsnotify"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
"github.com/sjzar/chatlog/pkg/util"
)
const (
Message = "message"
Contact = "contact"
Session = "session"
Media = "media"
Voice = "voice"
)
var Groups = []*dbm.Group{
{
Name: Message,
Pattern: `^message_([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
{
Name: Contact,
Pattern: `^contact\.db$`,
BlackList: []string{},
},
{
Name: Session,
Pattern: `session\.db$`,
BlackList: []string{},
},
{
Name: Media,
Pattern: `^hardlink\.db$`,
BlackList: []string{},
},
{
Name: Voice,
Pattern: `^media_([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
}
// MessageDBInfo 存储消息数据库的信息
type MessageDBInfo struct {
FilePath string
StartTime time.Time
EndTime time.Time
}
type DataSource struct {
path string
dbm *dbm.DBManager
// 消息数据库信息
messageInfos []MessageDBInfo
}
func New(path string) (*DataSource, error) {
ds := &DataSource{
path: path,
dbm: dbm.NewDBManager(path),
messageInfos: make([]MessageDBInfo, 0),
}
for _, g := range Groups {
ds.dbm.AddGroup(g)
}
if err := ds.dbm.Start(); err != nil {
return nil, err
}
if err := ds.initMessageDbs(); err != nil {
return nil, errors.DBInitFailed(err)
}
ds.dbm.AddCallback(Message, func(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := ds.initMessageDbs(); err != nil {
log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name)
}
return nil
})
return ds, nil
}
func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error {
if name == "chatroom" {
name = Contact
}
return ds.dbm.AddCallback(name, callback)
}
func (ds *DataSource) initMessageDbs() error {
dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil {
if strings.Contains(err.Error(), "db file not found") {
ds.messageInfos = make([]MessageDBInfo, 0)
return nil
}
return err
}
// 处理每个数据库文件
infos := make([]MessageDBInfo, 0)
for _, filePath := range dbPaths {
db, err := ds.dbm.OpenDB(filePath)
if err != nil {
log.Err(err).Msgf("获取数据库 %s 失败", filePath)
continue
}
// 获取 Timestamp 表中的开始时间
var startTime time.Time
var timestamp int64
row := db.QueryRow("SELECT timestamp FROM Timestamp LIMIT 1")
if err := row.Scan(&timestamp); err != nil {
log.Err(err).Msgf("获取数据库 %s 的时间戳失败", filePath)
continue
}
startTime = time.Unix(timestamp, 0)
// 保存数据库信息
infos = append(infos, MessageDBInfo{
FilePath: filePath,
StartTime: startTime,
})
}
// 按照 StartTime 排序数据库文件
sort.Slice(infos, func(i, j int) bool {
return infos[i].StartTime.Before(infos[j].StartTime)
})
// 设置结束时间
for i := range infos {
if i == len(infos)-1 {
infos[i].EndTime = time.Now()
} else {
infos[i].EndTime = infos[i+1].StartTime
}
}
ds.messageInfos = infos
return nil
}
// getDBInfosForTimeRange 获取时间范围内的数据库信息
func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo {
var dbs []MessageDBInfo
for _, info := range ds.messageInfos {
if info.StartTime.Before(endTime) && info.EndTime.After(startTime) {
dbs = append(dbs, info)
}
}
return dbs
}
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
if talker == "" {
return nil, errors.ErrTalkerEmpty
}
// 解析talker参数支持多个talker以英文逗号分隔
talkers := util.Str2List(talker, ",")
if len(talkers) == 0 {
return nil, errors.ErrTalkerEmpty
}
// 找到时间范围内的数据库文件
dbInfos := ds.getDBInfosForTimeRange(startTime, endTime)
if len(dbInfos) == 0 {
return nil, errors.TimeRangeNotFound(startTime, endTime)
}
// 解析sender参数支持多个发送者以英文逗号分隔
senders := util.Str2List(sender, ",")
// 预编译正则表达式如果有keyword
var regex *regexp.Regexp
if keyword != "" {
var err error
regex, err = regexp.Compile(keyword)
if err != nil {
return nil, errors.QueryFailed("invalid regex pattern", err)
}
}
// 从每个相关数据库中查询消息,并在读取时进行过滤
filteredMessages := []*model.Message{}
for _, dbInfo := range dbInfos {
// 检查上下文是否已取消
if err := ctx.Err(); err != nil {
return nil, err
}
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath)
continue
}
// 对每个talker进行查询
for _, talkerItem := range talkers {
// 构建表名
_talkerMd5Bytes := md5.Sum([]byte(talkerItem))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
tableName := "Msg_" + talkerMd5
// 检查表是否存在
var exists bool
err = db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
// 表不存在继续下一个talker
continue
}
return nil, errors.QueryFailed("", err)
}
// 构建查询条件
conditions := []string{"create_time >= ? AND create_time <= ?"}
args := []interface{}{startTime.Unix(), endTime.Unix()}
log.Debug().Msgf("Table name: %s", tableName)
log.Debug().Msgf("Start time: %d, End time: %d", startTime.Unix(), endTime.Unix())
query := fmt.Sprintf(`
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
// 如果表不存在SQLite 会返回错误
if strings.Contains(err.Error(), "no such table") {
continue
}
log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath)
continue
}
// 处理查询结果,在读取时进行过滤
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
&msg.Status,
)
if err != nil {
rows.Close()
return nil, errors.ScanRowFailed(err)
}
// 将消息转换为标准格式
message := msg.Wrap(talkerItem)
// 应用sender过滤
if len(senders) > 0 {
senderMatch := false
for _, s := range senders {
if message.Sender == s {
senderMatch = true
break
}
}
if !senderMatch {
continue // 不匹配sender跳过此消息
}
}
// 应用keyword过滤
if regex != nil {
plainText := message.PlainTextContent()
if !regex.MatchString(plainText) {
continue // 不匹配keyword跳过此消息
}
}
// 通过所有过滤条件,保留此消息
filteredMessages = append(filteredMessages, message)
// 检查是否已经满足分页处理数量
if limit > 0 && len(filteredMessages) >= offset+limit {
// 已经获取了足够的消息,可以提前返回
rows.Close()
// 对所有消息按时间排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Seq < filteredMessages[j].Seq
})
// 处理分页
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
}
rows.Close()
}
}
// 对所有消息按时间排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Seq < filteredMessages[j].Seq
})
// 处理分页
if limit > 0 {
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
return filteredMessages, nil
}
// 联系人
func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error) {
var query string
var args []interface{}
if key != "" {
// 按照关键字查询
query = `SELECT username, local_type, alias, remark, nick_name
FROM contact
WHERE username = ? OR alias = ? OR remark = ? OR nick_name = ?`
args = []interface{}{key, key, key, key}
} else {
// 查询所有联系人
query = `SELECT username, local_type, alias, remark, nick_name FROM contact`
}
// 添加排序、分页
query += ` ORDER BY username`
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
}
// 执行查询
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
contacts := []*model.Contact{}
for rows.Next() {
var contactV4 model.ContactV4
err := rows.Scan(
&contactV4.UserName,
&contactV4.LocalType,
&contactV4.Alias,
&contactV4.Remark,
&contactV4.NickName,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
contacts = append(contacts, contactV4.Wrap())
}
return contacts, nil
}
// 群聊
func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offset int) ([]*model.ChatRoom, error) {
var query string
var args []interface{}
// 执行查询
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
if key != "" {
// 按照关键字查询
query = `SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?`
args = []interface{}{key}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
chatRooms := []*model.ChatRoom{}
for rows.Next() {
var chatRoomV4 model.ChatRoomV4
err := rows.Scan(
&chatRoomV4.UserName,
&chatRoomV4.Owner,
&chatRoomV4.ExtBuffer,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
chatRooms = append(chatRooms, chatRoomV4.Wrap())
}
// 如果没有找到群聊,尝试通过联系人查找
if len(chatRooms) == 0 {
contacts, err := ds.GetContacts(ctx, key, 1, 0)
if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") {
// 再次尝试通过用户名查找群聊
rows, err := db.QueryContext(ctx,
`SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?`,
contacts[0].UserName)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
for rows.Next() {
var chatRoomV4 model.ChatRoomV4
err := rows.Scan(
&chatRoomV4.UserName,
&chatRoomV4.Owner,
&chatRoomV4.ExtBuffer,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
chatRooms = append(chatRooms, chatRoomV4.Wrap())
}
// 如果群聊记录不存在,但联系人记录存在,创建一个模拟的群聊对象
if len(chatRooms) == 0 {
chatRooms = append(chatRooms, &model.ChatRoom{
Name: contacts[0].UserName,
Users: make([]model.ChatRoomUser, 0),
User2DisplayName: make(map[string]string),
})
}
}
}
return chatRooms, nil
} else {
// 查询所有群聊
query = `SELECT username, owner, ext_buffer FROM chat_room`
// 添加排序、分页
query += ` ORDER BY username`
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
}
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
chatRooms := []*model.ChatRoom{}
for rows.Next() {
var chatRoomV4 model.ChatRoomV4
err := rows.Scan(
&chatRoomV4.UserName,
&chatRoomV4.Owner,
&chatRoomV4.ExtBuffer,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
chatRooms = append(chatRooms, chatRoomV4.Wrap())
}
return chatRooms, nil
}
}
// 最近会话
func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset int) ([]*model.Session, error) {
var query string
var args []interface{}
if key != "" {
// 按照关键字查询
query = `SELECT username, summary, last_timestamp, last_msg_sender, last_sender_display_name
FROM SessionTable
WHERE username = ? OR last_sender_display_name = ?
ORDER BY sort_timestamp DESC`
args = []interface{}{key, key}
} else {
// 查询所有会话
query = `SELECT username, summary, last_timestamp, last_msg_sender, last_sender_display_name
FROM SessionTable
ORDER BY sort_timestamp DESC`
}
// 添加分页
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
}
// 执行查询
db, err := ds.dbm.GetDB(Session)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
sessions := []*model.Session{}
for rows.Next() {
var sessionV4 model.SessionV4
err := rows.Scan(
&sessionV4.Username,
&sessionV4.Summary,
&sessionV4.LastTimestamp,
&sessionV4.LastMsgSender,
&sessionV4.LastSenderDisplayName,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
sessions = append(sessions, sessionV4.Wrap())
}
return sessions, nil
}
func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*model.Media, error) {
if key == "" {
return nil, errors.ErrKeyEmpty
}
var table string
switch _type {
case "image":
table = "image_hardlink_info_v3"
case "video":
table = "video_hardlink_info_v3"
case "file":
table = "file_hardlink_info_v3"
case "voice":
return ds.GetVoice(ctx, key)
default:
return nil, errors.MediaTypeUnsupported(_type)
}
query := fmt.Sprintf(`
SELECT
f.md5,
f.file_name,
f.file_size,
f.modify_time,
IFNULL(d1.username,""),
IFNULL(d2.username,"")
FROM
%s f
LEFT JOIN
dir2id d1 ON d1.rowid = f.dir1
LEFT JOIN
dir2id d2 ON d2.rowid = f.dir2
`, table)
query += " WHERE f.md5 = ? OR f.file_name LIKE ? || '%'"
args := []interface{}{key, key}
// 执行查询
db, err := ds.dbm.GetDB(Media)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
var media *model.Media
for rows.Next() {
var mediaV4 model.MediaV4
err := rows.Scan(
&mediaV4.Key,
&mediaV4.Name,
&mediaV4.Size,
&mediaV4.ModifyTime,
&mediaV4.Dir1,
&mediaV4.Dir2,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
mediaV4.Type = _type
media = mediaV4.Wrap()
// 跳过缩略图
if _type == "image" && !strings.Contains(media.Name, "_t") {
break
}
}
if media == nil {
return nil, errors.ErrMediaNotFound
}
return media, nil
}
func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, error) {
if key == "" {
return nil, errors.ErrKeyEmpty
}
query := `
SELECT voice_data
FROM VoiceInfo
WHERE svr_id = ?
`
args := []interface{}{key}
dbs, err := ds.dbm.GetDBs(Voice)
if err != nil {
return nil, errors.DBConnectFailed("", err)
}
for _, db := range dbs {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
for rows.Next() {
var voiceData []byte
err := rows.Scan(
&voiceData,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
if len(voiceData) > 0 {
return &model.Media{
Type: "voice",
Key: key,
Data: voiceData,
}, nil
}
}
}
return nil, errors.ErrMediaNotFound
}
func (ds *DataSource) Close() error {
return ds.dbm.Close()
}