support voice message (#31)

This commit is contained in:
Sarv
2025-04-12 03:10:32 +08:00
committed by GitHub
parent ba3563ad4e
commit f2aa923e99
15 changed files with 338 additions and 137 deletions

View File

@@ -23,6 +23,7 @@ const (
ContactFilePattern = "^contact\\.db$"
SessionFilePattern = "^session\\.db$"
MediaFilePattern = "^hardlink\\.db$"
VoiceFilePattern = "^media_([0-9]?[0-9])?\\.db$"
)
// MessageDBInfo 存储消息数据库的信息
@@ -38,6 +39,7 @@ type DataSource struct {
contactDb *sql.DB
sessionDb *sql.DB
mediaDb *sql.DB
voiceDb []*sql.DB
// 消息数据库信息
messageFiles []MessageDBInfo
@@ -47,6 +49,7 @@ func New(path string) (*DataSource, error) {
ds := &DataSource{
path: path,
messageDbs: make(map[string]*sql.DB),
voiceDb: make([]*sql.DB, 0),
messageFiles: make([]MessageDBInfo, 0),
}
@@ -62,6 +65,9 @@ func New(path string) (*DataSource, error) {
if err := ds.initMediaDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initVoiceDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
return ds, nil
}
@@ -173,6 +179,24 @@ func (ds *DataSource) initMediaDb(path string) error {
return nil
}
func (ds *DataSource) initVoiceDb(path string) error {
files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, VoiceFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, VoiceFilePattern, nil)
}
for _, file := range files {
db, err := sql.Open("sqlite3", file)
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
ds.voiceDb = append(ds.voiceDb, db)
}
return nil
}
// getDBInfosForTimeRange 获取时间范围内的数据库信息
func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo {
var dbs []MessageDBInfo
@@ -188,6 +212,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
if talker == "" {
return nil, errors.ErrTalkerEmpty
}
log.Debug().Msg(talker)
// 找到时间范围内的数据库文件
dbInfos := ds.getDBInfosForTimeRange(startTime, endTime)
@@ -215,7 +240,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
continue
}
messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker)
messages, err := ds.getMessagesFromDB(ctx, db, startTime, endTime, talker)
if err != nil {
log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath)
continue
@@ -260,12 +285,26 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
tableName := "Msg_" + talkerMd5
// 检查表是否存在
var exists bool
err := db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
// 表不存在,返回空结果
return []*model.Message{}, nil
}
return nil, errors.QueryFailed("", err)
}
// 构建查询条件
conditions := []string{"create_time >= ? AND create_time <= ?"}
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT m.sort_seq, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
@@ -293,6 +332,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
@@ -311,7 +351,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
}
// getMessagesFromDB 从数据库获取消息
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
// 构建表名
_talkerMd5Bytes := md5.Sum([]byte(talker))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
@@ -336,7 +376,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT m.sort_seq, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
@@ -361,6 +401,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
@@ -605,10 +646,6 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
return nil, errors.ErrKeyEmpty
}
if len(key) != 32 {
return nil, errors.ErrKeyLengthMust32
}
var table string
switch _type {
case "image":
@@ -617,6 +654,8 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
table = "video_hardlink_info_v3"
case "file":
table = "file_hardlink_info_v3"
case "voice":
return ds.GetVoice(ctx, key)
default:
return nil, errors.MediaTypeUnsupported(_type)
}
@@ -675,6 +714,46 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
return media, nil
}
func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, error) {
if key == "" {
return nil, errors.ErrKeyEmpty
}
query := `
SELECT voice_data
FROM VoiceInfo
WHERE svr_id = ?
`
args := []interface{}{key}
for _, db := range ds.voiceDb {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
for rows.Next() {
var voiceData []byte
err := rows.Scan(
&voiceData,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
if len(voiceData) > 0 {
return &model.Media{
Type: "voice",
Key: key,
Data: voiceData,
}, nil
}
}
}
return nil, errors.ErrMediaNotFound
}
func (ds *DataSource) Close() error {
var errs []error

View File

@@ -23,6 +23,7 @@ const (
ImageFilePattern = "^HardLinkImage\\.db$"
VideoFilePattern = "^HardLinkVideo\\.db$"
FileFilePattern = "^HardLinkFile\\.db$"
VoiceFilePattern = "^MediaMSG([0-9])?\\.db$"
)
// MessageDBInfo 保存消息数据库的信息
@@ -46,6 +47,7 @@ type DataSource struct {
imageDb *sql.DB
videoDb *sql.DB
fileDb *sql.DB
voiceDb []*sql.DB
}
// New 创建一个新的 WindowsV3DataSource
@@ -53,6 +55,7 @@ func New(path string) (*DataSource, error) {
ds := &DataSource{
messageFiles: make([]MessageDBInfo, 0),
messageDbs: make(map[string]*sql.DB),
voiceDb: make([]*sql.DB, 0),
}
// 初始化消息数据库
@@ -69,6 +72,10 @@ func New(path string) (*DataSource, error) {
return nil, errors.DBInitFailed(err)
}
if err := ds.initVoiceDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
return ds, nil
}
@@ -238,6 +245,24 @@ func (ds *DataSource) initMediaDb(path string) error {
return nil
}
func (ds *DataSource) initVoiceDb(path string) error {
files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, VoiceFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, VoiceFilePattern, nil)
}
for _, file := range files {
db, err := sql.Open("sqlite3", file)
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
ds.voiceDb = append(ds.voiceDb, db)
}
return nil
}
// getDBInfosForTimeRange 获取时间范围内的数据库信息
func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo {
var dbs []MessageDBInfo
@@ -293,7 +318,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
}
query := fmt.Sprintf(`
SELECT Sequence, CreateTime, StrTalker, IsSender,
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
@@ -314,6 +339,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
@@ -377,7 +403,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
}
}
query := fmt.Sprintf(`
SELECT Sequence, CreateTime, StrTalker, IsSender,
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
@@ -406,6 +432,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
@@ -652,6 +679,10 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
return nil, errors.ErrKeyEmpty
}
if _type == "voice" {
return ds.GetVoice(ctx, key)
}
md5key, err := hex.DecodeString(key)
if err != nil {
return nil, errors.DecodeKeyFailed(err)
@@ -725,6 +756,46 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
return media, nil
}
func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, error) {
if key == "" {
return nil, errors.ErrKeyEmpty
}
query := `
SELECT Buf
FROM Media
WHERE Reserved0 = ?
`
args := []interface{}{key}
for _, db := range ds.voiceDb {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
for rows.Next() {
var voiceData []byte
err := rows.Scan(
&voiceData,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
if len(voiceData) > 0 {
return &model.Media{
Type: "voice",
Key: key,
Data: voiceData,
}, nil
}
}
}
return nil, errors.ErrMediaNotFound
}
// Close 实现 DataSource 接口的 Close 方法
func (ds *DataSource) Close() error {
var errs []error