diff --git a/internal/chatlog/database/service.go b/internal/chatlog/database/service.go index 2d0ac23..875f202 100644 --- a/internal/chatlog/database/service.go +++ b/internal/chatlog/database/service.go @@ -40,8 +40,8 @@ func (s *Service) GetDB() *wechatdb.DB { return s.db } -func (s *Service) GetMessages(start, end time.Time, talker string, limit, offset int) ([]*model.Message, error) { - return s.db.GetMessages(start, end, talker, limit, offset) +func (s *Service) GetMessages(start, end time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { + return s.db.GetMessages(start, end, talker, sender, keyword, limit, offset) } func (s *Service) GetContacts(key string, limit, offset int) (*wechatdb.GetContactsResp, error) { diff --git a/internal/chatlog/http/route.go b/internal/chatlog/http/route.go index e78cfe6..d32d103 100644 --- a/internal/chatlog/http/route.go +++ b/internal/chatlog/http/route.go @@ -75,11 +75,13 @@ func (s *Service) NoRoute(c *gin.Context) { func (s *Service) GetChatlog(c *gin.Context) { q := struct { - Time string `form:"time"` - Talker string `form:"talker"` - Limit int `form:"limit"` - Offset int `form:"offset"` - Format string `form:"format"` + Time string `form:"time"` + Talker string `form:"talker"` + Sender string `form:"sender"` + Keyword string `form:"keyword"` + Limit int `form:"limit"` + Offset int `form:"offset"` + Format string `form:"format"` }{} if err := c.BindQuery(&q); err != nil { @@ -100,7 +102,7 @@ func (s *Service) GetChatlog(c *gin.Context) { q.Offset = 0 } - messages, err := s.db.GetMessages(start, end, q.Talker, q.Limit, q.Offset) + messages, err := s.db.GetMessages(start, end, q.Talker, q.Sender, q.Keyword, q.Limit, q.Offset) if err != nil { errors.Err(c, err) return @@ -119,7 +121,7 @@ func (s *Service) GetChatlog(c *gin.Context) { c.Writer.Flush() for _, m := range messages { - c.Writer.WriteString(m.PlainText(len(q.Talker) == 0, c.Request.Host)) + c.Writer.WriteString(m.PlainText(strings.Contains(q.Talker, ","), util.PerfectTimeFormat(start, end), c.Request.Host)) c.Writer.WriteString("\n") c.Writer.Flush() } @@ -129,10 +131,10 @@ func (s *Service) GetChatlog(c *gin.Context) { func (s *Service) GetContacts(c *gin.Context) { q := struct { - Key string `form:"key"` - Limit int `form:"limit"` - Offset int `form:"offset"` - Format string `form:"format"` + Keyword string `form:"keyword"` + Limit int `form:"limit"` + Offset int `form:"offset"` + Format string `form:"format"` }{} if err := c.BindQuery(&q); err != nil { @@ -140,7 +142,7 @@ func (s *Service) GetContacts(c *gin.Context) { return } - list, err := s.db.GetContacts(q.Key, q.Limit, q.Offset) + list, err := s.db.GetContacts(q.Keyword, q.Limit, q.Offset) if err != nil { errors.Err(c, err) return @@ -174,10 +176,10 @@ func (s *Service) GetContacts(c *gin.Context) { func (s *Service) GetChatRooms(c *gin.Context) { q := struct { - Key string `form:"key"` - Limit int `form:"limit"` - Offset int `form:"offset"` - Format string `form:"format"` + Keyword string `form:"keyword"` + Limit int `form:"limit"` + Offset int `form:"offset"` + Format string `form:"format"` }{} if err := c.BindQuery(&q); err != nil { @@ -185,7 +187,7 @@ func (s *Service) GetChatRooms(c *gin.Context) { return } - list, err := s.db.GetChatRooms(q.Key, q.Limit, q.Offset) + list, err := s.db.GetChatRooms(q.Keyword, q.Limit, q.Offset) if err != nil { errors.Err(c, err) return @@ -218,10 +220,10 @@ func (s *Service) GetChatRooms(c *gin.Context) { func (s *Service) GetSessions(c *gin.Context) { q := struct { - Key string `form:"key"` - Limit int `form:"limit"` - Offset int `form:"offset"` - Format string `form:"format"` + Keyword string `form:"keyword"` + Limit int `form:"limit"` + Offset int `form:"offset"` + Format string `form:"format"` }{} if err := c.BindQuery(&q); err != nil { @@ -229,7 +231,7 @@ func (s *Service) GetSessions(c *gin.Context) { return } - sessions, err := s.db.GetSessions(q.Key, q.Limit, q.Offset) + sessions, err := s.db.GetSessions(q.Keyword, q.Limit, q.Offset) if err != nil { errors.Err(c, err) return diff --git a/internal/chatlog/http/static/index.htm b/internal/chatlog/http/static/index.htm index b8d05a1..7a53ed1 100644 --- a/internal/chatlog/http/static/index.htm +++ b/internal/chatlog/http/static/index.htm @@ -377,12 +377,12 @@

-
@@ -408,12 +408,12 @@

-
@@ -459,6 +459,26 @@ placeholder="wxid、群ID、备注名或昵称" /> +
+ + +
+
+ + +
- + \ No newline at end of file diff --git a/internal/chatlog/mcp/const.go b/internal/chatlog/mcp/const.go index 64ba277..21b2e6f 100644 --- a/internal/chatlog/mcp/const.go +++ b/internal/chatlog/mcp/const.go @@ -21,12 +21,12 @@ var ( InputSchema: mcp.ToolSchema{ Type: "object", Properties: mcp.M{ - "query": mcp.M{ + "keyword": mcp.M{ "type": "string", "description": "联系人的搜索关键词,可以是姓名、备注名或ID。", }, }, - Required: []string{"query"}, + Required: []string{"keyword"}, }, } @@ -36,12 +36,12 @@ var ( InputSchema: mcp.ToolSchema{ Type: "object", Properties: mcp.M{ - "query": mcp.M{ + "keyword": mcp.M{ "type": "string", "description": "群聊的搜索关键词,可以是群名称、群ID或相关描述", }, }, - Required: []string{"query"}, + Required: []string{"keyword"}, }, } @@ -55,24 +55,136 @@ var ( } ToolChatLog = mcp.Tool{ - Name: "chatlog", - Description: "查询特定时间或时间段内与特定联系人或群组的聊天记录。当用户需要回顾过去的对话内容、查找特定信息或想了解与某人/某群的历史交流时使用此工具。", + Name: "chatlog", + Description: `检索历史聊天记录,可根据时间、对话方、发送者和关键词等条件进行精确查询。当用户需要查找特定信息或想了解与某人/某群的历史交流时使用此工具。 + +【强制多步查询流程!】 +当查询特定话题或特定发送者发言时,必须严格按照以下流程使用,任何偏离都会导致错误的结果: + +步骤1: 初步定位相关消息 +- 使用keyword参数查找特定话题 +- 使用sender参数查找特定发送者的消息 +- 使用较宽时间范围初步查询 + +步骤2: 【必须执行】针对每个关键结果点分别获取上下文 +- 必须对步骤1返回的每个时间点T1, T2, T3...分别执行独立查询(时间范围接近的消息可以合并为一个查询) +- 每次独立查询必须移除keyword参数 +- 每次独立查询必须移除sender参数 +- 每次独立查询使用"Tn前后15-30分钟"的窄范围 +- 每次独立查询仅保留talker参数 + +步骤3: 【必须执行】综合分析所有上下文 +- 必须等待所有步骤2的查询结果返回后再进行分析 +- 必须综合考虑所有上下文信息后再回答用户 + +【严格执行规则!】 +- 禁止仅凭步骤1的结果直接回答用户 +- 禁止在步骤2使用过大的时间范围一次性查询所有上下文 +- 禁止跳过步骤2或步骤3 +- 必须对每个关键结果点分别执行独立的上下文查询 + +【执行示例】 +正确流程示例: +1. 步骤1: chatlog(time="2023-04-01~2023-04-30", talker="工作群", keyword="项目进度") + 返回结果: 4月5日、4月12日、4月20日有相关消息 +2. 步骤2: + - 查询1: chatlog(time="2023-04-05/09:30~2023-04-05/10:30", talker="工作群") // 注意没有keyword + - 查询2: chatlog(time="2023-04-12/14:00~2023-04-12/15:00", talker="工作群") // 注意没有keyword + - 查询3: chatlog(time="2023-04-20/16:00~2023-04-20/17:00", talker="工作群") // 注意没有keyword +3. 步骤3: 综合分析所有上下文后回答用户 + +错误流程示例: +- 仅执行步骤1后直接回答 +- 步骤2使用time="2023-04-01~2023-04-30"一次性查询 +- 步骤2仍然保留keyword或sender参数 + +【自我检查】回答用户前必须自问: +- 我是否对每个关键时间点都执行了独立的上下文查询? +- 我是否在上下文查询中移除了keyword和sender参数? +- 我是否分析了所有上下文后再回答? +- 如果上述任一问题答案为"否",则必须纠正流程 + +返回格式:"昵称(ID) 时间\n消息内容\n昵称(ID) 时间\n消息内容" +当查询多个Talker时,返回格式为:"昵称(ID)\n[TalkerName(Talker)] 时间\n消息内容" + +重要提示: +1. 当用户询问特定时间段内的聊天记录时,必须使用正确的时间格式,特别是包含小时和分钟的查询 +2. 对于"今天下午4点到5点聊了啥"这类查询,正确的时间参数格式应为"2023-04-18/16:00~2023-04-18/17:00" +3. 当用户询问具体群聊中某人的聊天记录时,使用"sender"参数 +4. 当用户询问包含特定关键词的聊天记录时,使用"keyword"参数`, InputSchema: mcp.ToolSchema{ Type: "object", Properties: mcp.M{ "time": mcp.M{ - "type": "string", - "description": "查询的时间点或时间段。可以是具体时间,例如 YYYY-MM-DD,也可以是时间段,例如 YYYY-MM-DD~YYYY-MM-DD,时间段之间用\"~\"分隔。", + "type": "string", + "description": `指定查询的时间点或时间范围,格式必须严格遵循以下规则: + +【单一时间点格式】 +- 精确到日:"2023-04-18"或"20230418" +- 精确到分钟(必须包含斜杠和冒号):"2023-04-18/14:30"或"20230418/14:30"(表示2023年4月18日14点30分) + +【时间范围格式】(使用"~"分隔起止时间) +- 日期范围:"2023-04-01~2023-04-18" +- 同一天的时间段:"2023-04-18/14:30~2023-04-18/15:45" + * 表示2023年4月18日14点30分到15点45分之间 + +【重要提示】包含小时分钟的格式必须使用斜杠和冒号:"/"和":" +正确示例:"2023-04-18/16:30"(4月18日下午4点30分) +错误示例:"2023-04-18 16:30"、"2023-04-18T16:30" + +【其他支持的格式】 +- 年份:"2023" +- 月份:"2023-04"或"202304"`, }, "talker": mcp.M{ - "type": "string", - "description": "交谈对象,可以是联系人或群聊。支持使用ID、昵称、备注名等进行查询。", + "type": "string", + "description": `指定对话方(联系人或群组) +- 可使用ID、昵称或备注名 +- 多个对话方用","分隔,如:"张三,李四,工作群" +- 【重要】这是多步查询中唯一应保留的参数`, + }, + "sender": mcp.M{ + "type": "string", + "description": `指定群聊中的发送者 +- 仅在查询群聊记录时有效 +- 多个发送者用","分隔,如:"张三,李四" +- 可使用ID、昵称或备注名 +【重要】查询特定发送者的消息时: + 1. 第一步:使用sender参数初步定位多个相关消息时间点 + 2. 后续步骤:必须移除sender参数,分别查询每个时间点前后的完整对话 + 3. 错误示例:对所有找到的消息一次性查询大范围上下文 + 4. 正确示例:对每个时间点T分别执行查询"T前后15-30分钟"(不带sender)`, + }, + "keyword": mcp.M{ + "type": "string", + "description": `搜索内容中的关键词 +- 支持正则表达式匹配 +- 【重要】查询特定话题时: + 1. 第一步:使用keyword参数初步定位多个相关消息时间点 + 2. 后续步骤:必须移除keyword参数,分别查询每个时间点前后的完整对话 + 3. 错误示例:对所有找到的关键词消息一次性查询大范围上下文 + 4. 正确示例:对每个时间点T分别执行查询"T前后15-30分钟"(不带keyword)`, }, }, Required: []string{"time", "talker"}, }, } + ToolCurrentTime = mcp.Tool{ + Name: "current_time", + Description: `获取当前系统时间,返回RFC3339格式的时间字符串(包含用户本地时区信息)。 +使用场景: +- 当用户询问"总结今日聊天记录"、"本周都聊了啥"等当前时间问题 +- 当用户提及"昨天"、"上周"、"本月"等相对时间概念,需要确定基准时间点 +- 需要执行依赖当前时间的计算(如"上个月5号我们有开会吗") +返回示例:2025-04-18T21:29:00+08:00 +注意:此工具不需要任何输入参数,直接调用即可获取当前时间。`, + InputSchema: mcp.ToolSchema{ + Type: "object", + Properties: mcp.M{}, + }, + } + ResourceRecentChat = mcp.Resource{ Name: "最近会话", URI: "session://recent", diff --git a/internal/chatlog/mcp/service.go b/internal/chatlog/mcp/service.go index 8ebff79..8f5f0a4 100644 --- a/internal/chatlog/mcp/service.go +++ b/internal/chatlog/mcp/service.go @@ -7,6 +7,7 @@ import ( "fmt" "net/url" "strings" + "time" "github.com/sjzar/chatlog/internal/chatlog/ctx" "github.com/sjzar/chatlog/internal/chatlog/database" @@ -83,6 +84,7 @@ func (s *Service) processMCP(session *mcp.Session, req *mcp.Request) { ToolChatRoom, ToolRecentChat, ToolChatLog, + ToolCurrentTime, }}) case mcp.MethodToolsCall: err = s.toolsCall(session, req) @@ -130,13 +132,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error { buf := &bytes.Buffer{} switch callReq.Name { case "query_contact": - query := "" - if v, ok := callReq.Arguments["query"]; ok { - query = v.(string) + keyword := "" + if v, ok := callReq.Arguments["keyword"]; ok { + keyword = v.(string) } limit := util.MustAnyToInt(callReq.Arguments["limit"]) offset := util.MustAnyToInt(callReq.Arguments["offset"]) - list, err := s.db.GetContacts(query, limit, offset) + list, err := s.db.GetContacts(keyword, limit, offset) if err != nil { return fmt.Errorf("无法获取联系人列表: %v", err) } @@ -145,13 +147,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error { buf.WriteString(fmt.Sprintf("%s,%s,%s,%s\n", contact.UserName, contact.Alias, contact.Remark, contact.NickName)) } case "query_chat_room": - query := "" - if v, ok := callReq.Arguments["query"]; ok { - query = v.(string) + keyword := "" + if v, ok := callReq.Arguments["keyword"]; ok { + keyword = v.(string) } limit := util.MustAnyToInt(callReq.Arguments["limit"]) offset := util.MustAnyToInt(callReq.Arguments["offset"]) - list, err := s.db.GetChatRooms(query, limit, offset) + list, err := s.db.GetChatRooms(keyword, limit, offset) if err != nil { return fmt.Errorf("无法获取群聊列表: %v", err) } @@ -160,13 +162,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error { buf.WriteString(fmt.Sprintf("%s,%s,%s,%s,%d\n", chatRoom.Name, chatRoom.Remark, chatRoom.NickName, chatRoom.Owner, len(chatRoom.Users))) } case "query_recent_chat": - query := "" - if v, ok := callReq.Arguments["query"]; ok { - query = v.(string) + keyword := "" + if v, ok := callReq.Arguments["keyword"]; ok { + keyword = v.(string) } limit := util.MustAnyToInt(callReq.Arguments["limit"]) offset := util.MustAnyToInt(callReq.Arguments["offset"]) - data, err := s.db.GetSessions(query, limit, offset) + data, err := s.db.GetSessions(keyword, limit, offset) if err != nil { return fmt.Errorf("无法获取会话列表: %v", err) } @@ -190,16 +192,29 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error { if v, ok := callReq.Arguments["talker"]; ok { talker = v.(string) } + sender := "" + if v, ok := callReq.Arguments["sender"]; ok { + sender = v.(string) + } + keyword := "" + if v, ok := callReq.Arguments["keyword"]; ok { + keyword = v.(string) + } limit := util.MustAnyToInt(callReq.Arguments["limit"]) offset := util.MustAnyToInt(callReq.Arguments["offset"]) - messages, err := s.db.GetMessages(start, end, talker, limit, offset) + messages, err := s.db.GetMessages(start, end, talker, sender, keyword, limit, offset) if err != nil { return fmt.Errorf("无法获取聊天记录: %v", err) } + if len(messages) == 0 { + buf.WriteString("未找到符合查询条件的聊天记录") + } for _, m := range messages { - buf.WriteString(m.PlainText(len(talker) == 0, "")) + buf.WriteString(m.PlainText(strings.Contains(talker, ","), util.PerfectTimeFormat(start, end), "")) buf.WriteString("\n") } + case "current_time": + buf.WriteString(time.Now().Local().Format(time.RFC3339)) default: return fmt.Errorf("未支持的工具: %s", callReq.Name) } @@ -228,7 +243,6 @@ func (s *Service) resourcesRead(session *mcp.Session, req *mcp.Request) error { buf := &bytes.Buffer{} switch u.Scheme { case "contact": - list, err := s.db.GetContacts(u.Host, 0, 0) if err != nil { return fmt.Errorf("无法获取联系人列表: %v", err) @@ -262,12 +276,15 @@ func (s *Service) resourcesRead(session *mcp.Session, req *mcp.Request) error { } limit := util.MustAnyToInt(u.Query().Get("limit")) offset := util.MustAnyToInt(u.Query().Get("offset")) - messages, err := s.db.GetMessages(start, end, u.Host, limit, offset) + messages, err := s.db.GetMessages(start, end, u.Host, "", "", limit, offset) if err != nil { return fmt.Errorf("无法获取聊天记录: %v", err) } + if len(messages) == 0 { + buf.WriteString("未找到符合查询条件的聊天记录") + } for _, m := range messages { - buf.WriteString(m.PlainText(len(u.Host) == 0, "")) + buf.WriteString(m.PlainText(strings.Contains(u.Host, ","), util.PerfectTimeFormat(start, end), "")) buf.WriteString("\n") } default: diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 68d2b3d..bdb45c5 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -65,7 +65,7 @@ func Newf(cause error, code int, format string, args ...interface{}) *Error { return &Error{ Message: fmt.Sprintf(format, args...), Cause: cause, - Code: http.StatusInternalServerError, + Code: code, } } diff --git a/internal/errors/middleware.go b/internal/errors/middleware.go index 2965923..67313ff 100644 --- a/internal/errors/middleware.go +++ b/internal/errors/middleware.go @@ -2,6 +2,7 @@ package errors import ( "net/http" + "runtime/debug" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -50,7 +51,7 @@ func RecoveryMiddleware() gin.HandlerFunc { } // 记录错误日志 - log.Err(err).Msg("PANIC RECOVERED") + log.Err(err).Msgf("PANIC RECOVERED\n%s", string(debug.Stack())) // 返回 500 错误 c.JSON(http.StatusInternalServerError, err) diff --git a/internal/model/message.go b/internal/model/message.go index b8b6b10..eb783b8 100644 --- a/internal/model/message.go +++ b/internal/model/message.go @@ -183,7 +183,11 @@ func (m *Message) SetContent(key string, value interface{}) { m.Contents[key] = value } -func (m *Message) PlainText(showChatRoom bool, host string) string { +func (m *Message) PlainText(showChatRoom bool, timeFormat string, host string) string { + + if timeFormat == "" { + timeFormat = "01-02 15:04:05" + } m.SetContent("host", host) @@ -216,7 +220,7 @@ func (m *Message) PlainText(showChatRoom bool, host string) string { buf.WriteString("] ") } - buf.WriteString(m.Time.Format("2006-01-02 15:04:05")) + buf.WriteString(m.Time.Format(timeFormat)) buf.WriteString("\n") buf.WriteString(m.PlainTextContent()) @@ -262,7 +266,11 @@ func (m *Message) PlainTextContent() string { if !ok { return "[合并转发]" } - return recordInfo.String("", m.Contents["host"].(string)) + host := "" + if m.Contents["host"] != nil { + host = m.Contents["host"].(string) + } + return recordInfo.String("", host) case 33, 36: if m.Contents["title"] == "" { return "[小程序]" @@ -290,7 +298,11 @@ func (m *Message) PlainTextContent() string { return "> [引用]\n" + m.Content } buf := strings.Builder{} - referContent := refer.PlainText(false, m.Contents["host"].(string)) + host := "" + if m.Contents["host"] != nil { + host = m.Contents["host"].(string) + } + referContent := refer.PlainText(false, "", host) for _, line := range strings.Split(referContent, "\n") { if line == "" { continue diff --git a/internal/wechatdb/datasource/darwinv3/datasource.go b/internal/wechatdb/datasource/darwinv3/datasource.go index 26f9cf3..f8f97ca 100644 --- a/internal/wechatdb/datasource/darwinv3/datasource.go +++ b/internal/wechatdb/datasource/darwinv3/datasource.go @@ -5,6 +5,8 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "regexp" + "sort" "strings" "time" @@ -15,6 +17,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" + "github.com/sjzar/chatlog/pkg/util" ) const ( @@ -188,70 +191,162 @@ func (ds *DataSource) initChatRoomDb() error { return nil } -// GetMessages 实现获取消息的方法 -func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { - // 在 darwinv3 中,每个联系人/群聊的消息存储在单独的表中,表名为 Chat_md5(talker) - // 首先需要找到对应的表名 +func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { if talker == "" { return nil, errors.ErrTalkerEmpty } - _talkerMd5Bytes := md5.Sum([]byte(talker)) - talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) - dbPath, ok := ds.talkerDBMap[talkerMd5] - if !ok { - return nil, errors.TalkerNotFound(talker) + // 解析talker参数,支持多个talker(以英文逗号分隔) + talkers := util.Str2List(talker, ",") + if len(talkers) == 0 { + return nil, errors.ErrTalkerEmpty } - db, err := ds.dbm.OpenDB(dbPath) - if err != nil { - return nil, err - } - tableName := fmt.Sprintf("Chat_%s", talkerMd5) - // 构建查询条件 - query := fmt.Sprintf(` - SELECT msgCreateTime, msgContent, messageType, mesDes - FROM %s - WHERE msgCreateTime >= ? AND msgCreateTime <= ? - ORDER BY msgCreateTime ASC - `, tableName) + // 解析sender参数,支持多个发送者(以英文逗号分隔) + senders := util.Str2List(sender, ",") - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) + // 预编译正则表达式(如果有keyword) + var regex *regexp.Regexp + if keyword != "" { + var err error + regex, err = regexp.Compile(keyword) + if err != nil { + return nil, errors.QueryFailed("invalid regex pattern", err) } } - // 执行查询 - rows, err := db.QueryContext(ctx, query, startTime.Unix(), endTime.Unix()) - if err != nil { - return nil, errors.QueryFailed(query, err) - } - defer rows.Close() + // 从每个相关数据库中查询消息,并在读取时进行过滤 + filteredMessages := []*model.Message{} - // 处理查询结果 - messages := []*model.Message{} - for rows.Next() { - var msg model.MessageDarwinV3 - err := rows.Scan( - &msg.MsgCreateTime, - &msg.MsgContent, - &msg.MessageType, - &msg.MesDes, - ) - if err != nil { - log.Err(err).Msgf("扫描消息行失败") + // 对每个talker进行查询 + for _, talkerItem := range talkers { + // 检查上下文是否已取消 + if err := ctx.Err(); err != nil { + return nil, err + } + + // 在 darwinv3 中,需要先找到对应的数据库 + _talkerMd5Bytes := md5.Sum([]byte(talkerItem)) + talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) + dbPath, ok := ds.talkerDBMap[talkerMd5] + if !ok { + // 如果找不到对应的数据库,跳过此talker continue } - // 将消息包装为通用模型 - message := msg.Wrap(talker) - messages = append(messages, message) + db, err := ds.dbm.OpenDB(dbPath) + if err != nil { + log.Error().Msgf("数据库 %s 未打开", dbPath) + continue + } + + tableName := fmt.Sprintf("Chat_%s", talkerMd5) + + // 构建查询条件 + query := fmt.Sprintf(` + SELECT msgCreateTime, msgContent, messageType, mesDes + FROM %s + WHERE msgCreateTime >= ? AND msgCreateTime <= ? + ORDER BY msgCreateTime ASC + `, tableName) + + // 执行查询 + rows, err := db.QueryContext(ctx, query, startTime.Unix(), endTime.Unix()) + if err != nil { + // 如果表不存在,跳过此talker + if strings.Contains(err.Error(), "no such table") { + continue + } + log.Err(err).Msgf("从数据库 %s 查询消息失败", dbPath) + continue + } + + // 处理查询结果,在读取时进行过滤 + for rows.Next() { + var msg model.MessageDarwinV3 + err := rows.Scan( + &msg.MsgCreateTime, + &msg.MsgContent, + &msg.MessageType, + &msg.MesDes, + ) + if err != nil { + rows.Close() + log.Err(err).Msgf("扫描消息行失败") + continue + } + + // 将消息包装为通用模型 + message := msg.Wrap(talkerItem) + + // 应用sender过滤 + if len(senders) > 0 { + senderMatch := false + for _, s := range senders { + if message.Sender == s { + senderMatch = true + break + } + } + if !senderMatch { + continue // 不匹配sender,跳过此消息 + } + } + + // 应用keyword过滤 + if regex != nil { + plainText := message.PlainTextContent() + if !regex.MatchString(plainText) { + continue // 不匹配keyword,跳过此消息 + } + } + + // 通过所有过滤条件,保留此消息 + filteredMessages = append(filteredMessages, message) + + // 检查是否已经满足分页处理数量 + if limit > 0 && len(filteredMessages) >= offset+limit { + // 已经获取了足够的消息,可以提前返回 + rows.Close() + + // 对所有消息按时间排序 + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Seq < filteredMessages[j].Seq + }) + + // 处理分页 + if offset >= len(filteredMessages) { + return []*model.Message{}, nil + } + end := offset + limit + if end > len(filteredMessages) { + end = len(filteredMessages) + } + return filteredMessages[offset:end], nil + } + } + rows.Close() } - return messages, nil + // 对所有消息按时间排序 + // FIXME 不同 talker 需要使用 Time 排序 + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Time.Before(filteredMessages[j].Time) + }) + + // 处理分页 + if limit > 0 { + if offset >= len(filteredMessages) { + return []*model.Message{}, nil + } + end := offset + limit + if end > len(filteredMessages) { + end = len(filteredMessages) + } + return filteredMessages[offset:end], nil + } + + return filteredMessages, nil } // 从表名中提取 talker diff --git a/internal/wechatdb/datasource/datasource.go b/internal/wechatdb/datasource/datasource.go index 56d68bd..6358105 100644 --- a/internal/wechatdb/datasource/datasource.go +++ b/internal/wechatdb/datasource/datasource.go @@ -16,7 +16,7 @@ import ( type DataSource interface { // 消息 - GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) + GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) // 联系人 GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error) diff --git a/internal/wechatdb/datasource/v4/datasource.go b/internal/wechatdb/datasource/v4/datasource.go index b837047..8dc69d8 100644 --- a/internal/wechatdb/datasource/v4/datasource.go +++ b/internal/wechatdb/datasource/v4/datasource.go @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/hex" "fmt" + "regexp" "sort" "strings" "time" @@ -17,6 +18,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" + "github.com/sjzar/chatlog/pkg/util" ) const ( @@ -175,11 +177,16 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes return dbs } -func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { +func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { if talker == "" { return nil, errors.ErrTalkerEmpty } - log.Debug().Msg(talker) + + // 解析talker参数,支持多个talker(以英文逗号分隔) + talkers := util.Str2List(talker, ",") + if len(talkers) == 0 { + return nil, errors.ErrTalkerEmpty + } // 找到时间范围内的数据库文件 dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) @@ -187,13 +194,21 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T return nil, errors.TimeRangeNotFound(startTime, endTime) } - if len(dbInfos) == 1 { - // LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理 - return ds.getMessagesSingleFile(ctx, dbInfos[0], startTime, endTime, talker, limit, offset) + // 解析sender参数,支持多个发送者(以英文逗号分隔) + senders := util.Str2List(sender, ",") + + // 预编译正则表达式(如果有keyword) + var regex *regexp.Regexp + if keyword != "" { + var err error + regex, err = regexp.Compile(keyword) + if err != nil { + return nil, errors.QueryFailed("invalid regex pattern", err) + } } - // 从每个相关数据库中查询消息 - totalMessages := []*model.Message{} + // 从每个相关数据库中查询消息,并在读取时进行过滤 + filteredMessages := []*model.Message{} for _, dbInfo := range dbInfos { // 检查上下文是否已取消 @@ -207,183 +222,141 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T continue } - messages, err := ds.getMessagesFromDB(ctx, db, startTime, endTime, talker) - if err != nil { - log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) - continue - } + // 对每个talker进行查询 + for _, talkerItem := range talkers { + // 构建表名 + _talkerMd5Bytes := md5.Sum([]byte(talkerItem)) + talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) + tableName := "Msg_" + talkerMd5 - totalMessages = append(totalMessages, messages...) + // 检查表是否存在 + var exists bool + err = db.QueryRowContext(ctx, + "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", + tableName).Scan(&exists) - if limit+offset > 0 && len(totalMessages) >= limit+offset { - break + if err != nil { + if err == sql.ErrNoRows { + // 表不存在,继续下一个talker + continue + } + return nil, errors.QueryFailed("", err) + } + + // 构建查询条件 + conditions := []string{"create_time >= ? AND create_time <= ?"} + args := []interface{}{startTime.Unix(), endTime.Unix()} + log.Debug().Msgf("Table name: %s", tableName) + log.Debug().Msgf("Start time: %d, End time: %d", startTime.Unix(), endTime.Unix()) + + query := fmt.Sprintf(` + SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status + FROM %s m + LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid + WHERE %s + ORDER BY m.sort_seq ASC + `, tableName, strings.Join(conditions, " AND ")) + + // 执行查询 + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + // 如果表不存在,SQLite 会返回错误 + if strings.Contains(err.Error(), "no such table") { + continue + } + log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath) + continue + } + + // 处理查询结果,在读取时进行过滤 + for rows.Next() { + var msg model.MessageV4 + err := rows.Scan( + &msg.SortSeq, + &msg.ServerID, + &msg.LocalType, + &msg.UserName, + &msg.CreateTime, + &msg.MessageContent, + &msg.PackedInfoData, + &msg.Status, + ) + if err != nil { + rows.Close() + return nil, errors.ScanRowFailed(err) + } + + // 将消息转换为标准格式 + message := msg.Wrap(talkerItem) + + // 应用sender过滤 + if len(senders) > 0 { + senderMatch := false + for _, s := range senders { + if message.Sender == s { + senderMatch = true + break + } + } + if !senderMatch { + continue // 不匹配sender,跳过此消息 + } + } + + // 应用keyword过滤 + if regex != nil { + plainText := message.PlainTextContent() + if !regex.MatchString(plainText) { + continue // 不匹配keyword,跳过此消息 + } + } + + // 通过所有过滤条件,保留此消息 + filteredMessages = append(filteredMessages, message) + + // 检查是否已经满足分页处理数量 + if limit > 0 && len(filteredMessages) >= offset+limit { + // 已经获取了足够的消息,可以提前返回 + rows.Close() + + // 对所有消息按时间排序 + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Seq < filteredMessages[j].Seq + }) + + // 处理分页 + if offset >= len(filteredMessages) { + return []*model.Message{}, nil + } + end := offset + limit + if end > len(filteredMessages) { + end = len(filteredMessages) + } + return filteredMessages[offset:end], nil + } + } + rows.Close() } } // 对所有消息按时间排序 - sort.Slice(totalMessages, func(i, j int) bool { - return totalMessages[i].Seq < totalMessages[j].Seq + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Seq < filteredMessages[j].Seq }) // 处理分页 if limit > 0 { - if offset >= len(totalMessages) { + if offset >= len(filteredMessages) { return []*model.Message{}, nil } end := offset + limit - if end > len(totalMessages) { - end = len(totalMessages) + if end > len(filteredMessages) { + end = len(filteredMessages) } - return totalMessages[offset:end], nil + return filteredMessages[offset:end], nil } - return totalMessages, nil -} - -// getMessagesSingleFile 从单个数据库文件获取消息 -func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { - db, err := ds.dbm.OpenDB(dbInfo.FilePath) - if err != nil { - return nil, errors.DBConnectFailed(dbInfo.FilePath, nil) - } - - // 构建表名 - _talkerMd5Bytes := md5.Sum([]byte(talker)) - talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) - tableName := "Msg_" + talkerMd5 - - // 检查表是否存在 - var exists bool - err = db.QueryRowContext(ctx, - "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", - tableName).Scan(&exists) - - if err != nil { - if err == sql.ErrNoRows { - // 表不存在,返回空结果 - return []*model.Message{}, nil - } - return nil, errors.QueryFailed("", err) - } - - // 构建查询条件 - conditions := []string{"create_time >= ? AND create_time <= ?"} - args := []interface{}{startTime.Unix(), endTime.Unix()} - - query := fmt.Sprintf(` - SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status - FROM %s m - LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid - WHERE %s - ORDER BY m.sort_seq ASC - `, tableName, strings.Join(conditions, " AND ")) - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - } - - // 执行查询 - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.QueryFailed(query, err) - } - defer rows.Close() - - // 处理查询结果 - messages := []*model.Message{} - - for rows.Next() { - var msg model.MessageV4 - err := rows.Scan( - &msg.SortSeq, - &msg.ServerID, - &msg.LocalType, - &msg.UserName, - &msg.CreateTime, - &msg.MessageContent, - &msg.PackedInfoData, - &msg.Status, - ) - if err != nil { - return nil, errors.ScanRowFailed(err) - } - - messages = append(messages, msg.Wrap(talker)) - } - - return messages, nil -} - -// getMessagesFromDB 从数据库获取消息 -func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, startTime, endTime time.Time, talker string) ([]*model.Message, error) { - // 构建表名 - _talkerMd5Bytes := md5.Sum([]byte(talker)) - talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) - tableName := "Msg_" + talkerMd5 - - // 检查表是否存在 - var exists bool - err := db.QueryRowContext(ctx, - "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", - tableName).Scan(&exists) - - if err != nil { - if err == sql.ErrNoRows { - // 表不存在,返回空结果 - return []*model.Message{}, nil - } - return nil, errors.QueryFailed("", err) - } - - // 构建查询条件 - conditions := []string{"create_time >= ? AND create_time <= ?"} - args := []interface{}{startTime.Unix(), endTime.Unix()} - - query := fmt.Sprintf(` - SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status - FROM %s m - LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid - WHERE %s - ORDER BY m.sort_seq ASC - `, tableName, strings.Join(conditions, " AND ")) - - // 执行查询 - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - // 如果表不存在,SQLite 会返回错误 - if strings.Contains(err.Error(), "no such table") { - return []*model.Message{}, nil - } - return nil, errors.QueryFailed(query, err) - } - defer rows.Close() - - // 处理查询结果 - messages := []*model.Message{} - - for rows.Next() { - var msg model.MessageV4 - err := rows.Scan( - &msg.SortSeq, - &msg.ServerID, - &msg.LocalType, - &msg.UserName, - &msg.CreateTime, - &msg.MessageContent, - &msg.PackedInfoData, - &msg.Status, - ) - if err != nil { - return nil, errors.ScanRowFailed(err) - } - - messages = append(messages, msg.Wrap(talker)) - } - - return messages, nil + return filteredMessages, nil } // 联系人 diff --git a/internal/wechatdb/datasource/windowsv3/datasource.go b/internal/wechatdb/datasource/windowsv3/datasource.go index eeeb17e..60713d2 100644 --- a/internal/wechatdb/datasource/windowsv3/datasource.go +++ b/internal/wechatdb/datasource/windowsv3/datasource.go @@ -2,9 +2,9 @@ package windowsv3 import ( "context" - "database/sql" "encoding/hex" "fmt" + "regexp" "sort" "strings" "time" @@ -16,6 +16,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" + "github.com/sjzar/chatlog/pkg/util" ) const ( @@ -221,21 +222,38 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes return dbs } -// GetMessages 实现 DataSource 接口的 GetMessages 方法 -func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { +func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { + if talker == "" { + return nil, errors.ErrTalkerEmpty + } + + // 解析talker参数,支持多个talker(以英文逗号分隔) + talkers := util.Str2List(talker, ",") + if len(talkers) == 0 { + return nil, errors.ErrTalkerEmpty + } + // 找到时间范围内的数据库文件 dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) if len(dbInfos) == 0 { return nil, errors.TimeRangeNotFound(startTime, endTime) } - if len(dbInfos) == 1 { - // LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理 - return ds.getMessagesSingleFile(ctx, dbInfos[0], startTime, endTime, talker, limit, offset) + // 解析sender参数,支持多个发送者(以英文逗号分隔) + senders := util.Str2List(sender, ",") + + // 预编译正则表达式(如果有keyword) + var regex *regexp.Regexp + if keyword != "" { + var err error + regex, err = regexp.Compile(keyword) + if err != nil { + return nil, errors.QueryFailed("invalid regex pattern", err) + } } // 从每个相关数据库中查询消息 - totalMessages := []*model.Message{} + filteredMessages := []*model.Message{} for _, dbInfo := range dbInfos { // 检查上下文是否已取消 @@ -249,172 +267,137 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T continue } - messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker) - if err != nil { - log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) - continue - } + // 对每个talker进行查询 + for _, talkerItem := range talkers { + // 构建查询条件 + conditions := []string{"Sequence >= ? AND Sequence <= ?"} + args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} - totalMessages = append(totalMessages, messages...) + // 添加talker条件 + talkerID, ok := dbInfo.TalkerMap[talkerItem] + if ok { + conditions = append(conditions, "TalkerId = ?") + args = append(args, talkerID) + } else { + conditions = append(conditions, "StrTalker = ?") + args = append(args, talkerItem) + } - if limit+offset > 0 && len(totalMessages) >= limit+offset { - break + query := fmt.Sprintf(` + SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender, + Type, SubType, StrContent, CompressContent, BytesExtra + FROM MSG + WHERE %s + ORDER BY Sequence ASC + `, strings.Join(conditions, " AND ")) + + // 执行查询 + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + // 如果表不存在,跳过此talker + if strings.Contains(err.Error(), "no such table") { + continue + } + log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath) + continue + } + + // 处理查询结果,在读取时进行过滤 + for rows.Next() { + var msg model.MessageV3 + var compressContent []byte + var bytesExtra []byte + + err := rows.Scan( + &msg.MsgSvrID, + &msg.Sequence, + &msg.CreateTime, + &msg.StrTalker, + &msg.IsSender, + &msg.Type, + &msg.SubType, + &msg.StrContent, + &compressContent, + &bytesExtra, + ) + if err != nil { + rows.Close() + return nil, errors.ScanRowFailed(err) + } + msg.CompressContent = compressContent + msg.BytesExtra = bytesExtra + + // 将消息转换为标准格式 + message := msg.Wrap() + + // 应用sender过滤 + if len(senders) > 0 { + senderMatch := false + for _, s := range senders { + if message.Sender == s { + senderMatch = true + break + } + } + if !senderMatch { + continue // 不匹配sender,跳过此消息 + } + } + + // 应用keyword过滤 + if regex != nil { + plainText := message.PlainTextContent() + if !regex.MatchString(plainText) { + continue // 不匹配keyword,跳过此消息 + } + } + + // 通过所有过滤条件,保留此消息 + filteredMessages = append(filteredMessages, message) + + // 检查是否已经满足分页处理数量 + if limit > 0 && len(filteredMessages) >= offset+limit { + // 已经获取了足够的消息,可以提前返回 + rows.Close() + + // 对所有消息按时间排序 + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Seq < filteredMessages[j].Seq + }) + + // 处理分页 + if offset >= len(filteredMessages) { + return []*model.Message{}, nil + } + end := offset + limit + if end > len(filteredMessages) { + end = len(filteredMessages) + } + return filteredMessages[offset:end], nil + } + } + rows.Close() } } // 对所有消息按时间排序 - sort.Slice(totalMessages, func(i, j int) bool { - return totalMessages[i].Seq < totalMessages[j].Seq + sort.Slice(filteredMessages, func(i, j int) bool { + return filteredMessages[i].Seq < filteredMessages[j].Seq }) // 处理分页 if limit > 0 { - if offset >= len(totalMessages) { + if offset >= len(filteredMessages) { return []*model.Message{}, nil } end := offset + limit - if end > len(totalMessages) { - end = len(totalMessages) + if end > len(filteredMessages) { + end = len(filteredMessages) } - return totalMessages[offset:end], nil + return filteredMessages[offset:end], nil } - return totalMessages, nil -} - -// getMessagesSingleFile 从单个数据库文件获取消息 -func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { - db, err := ds.dbm.OpenDB(dbInfo.FilePath) - if err != nil { - return nil, errors.DBConnectFailed(dbInfo.FilePath, nil) - } - - // 构建查询条件 - conditions := []string{"Sequence >= ? AND Sequence <= ?"} - args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} - if len(talker) > 0 { - // TalkerId 有索引,优先使用 - talkerID, ok := dbInfo.TalkerMap[talker] - if ok { - conditions = append(conditions, "TalkerId = ?") - args = append(args, talkerID) - } else { - conditions = append(conditions, "StrTalker = ?") - args = append(args, talker) - } - } - query := fmt.Sprintf(` - SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender, - Type, SubType, StrContent, CompressContent, BytesExtra - FROM MSG - WHERE %s - ORDER BY Sequence ASC - `, strings.Join(conditions, " AND ")) - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - } - - // 执行查询 - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.QueryFailed(query, err) - } - defer rows.Close() - - // 处理查询结果 - totalMessages := []*model.Message{} - for rows.Next() { - var msg model.MessageV3 - var compressContent []byte - var bytesExtra []byte - err := rows.Scan( - &msg.MsgSvrID, - &msg.Sequence, - &msg.CreateTime, - &msg.StrTalker, - &msg.IsSender, - &msg.Type, - &msg.SubType, - &msg.StrContent, - &compressContent, - &bytesExtra, - ) - if err != nil { - return nil, errors.ScanRowFailed(err) - } - msg.CompressContent = compressContent - msg.BytesExtra = bytesExtra - totalMessages = append(totalMessages, msg.Wrap()) - } - return totalMessages, nil -} - -// getMessagesFromDB 从数据库获取消息 -func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) { - // 构建查询条件 - conditions := []string{"Sequence >= ? AND Sequence <= ?"} - args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} - - if len(talker) > 0 { - talkerID, ok := dbInfo.TalkerMap[talker] - if ok { - conditions = append(conditions, "TalkerId = ?") - args = append(args, talkerID) - } else { - conditions = append(conditions, "StrTalker = ?") - args = append(args, talker) - } - } - - query := fmt.Sprintf(` - SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender, - Type, SubType, StrContent, CompressContent, BytesExtra - FROM MSG - WHERE %s - ORDER BY Sequence ASC - `, strings.Join(conditions, " AND ")) - - // 执行查询 - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.QueryFailed(query, err) - } - defer rows.Close() - - // 处理查询结果 - messages := []*model.Message{} - for rows.Next() { - var msg model.MessageV3 - var compressContent []byte - var bytesExtra []byte - - err := rows.Scan( - &msg.MsgSvrID, - &msg.Sequence, - &msg.CreateTime, - &msg.StrTalker, - &msg.IsSender, - &msg.Type, - &msg.SubType, - &msg.StrContent, - &compressContent, - &bytesExtra, - ) - if err != nil { - return nil, errors.ScanRowFailed(err) - } - msg.CompressContent = compressContent - msg.BytesExtra = bytesExtra - - messages = append(messages, msg.Wrap()) - } - - return messages, nil + return filteredMessages, nil } // GetContacts 实现获取联系人信息的方法 diff --git a/internal/wechatdb/repository/chatroom.go b/internal/wechatdb/repository/chatroom.go index 04849f3..c69f6b0 100644 --- a/internal/wechatdb/repository/chatroom.go +++ b/internal/wechatdb/repository/chatroom.go @@ -18,8 +18,8 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error { } chatRoomMap := make(map[string]*model.ChatRoom) - remarkToChatRoom := make(map[string]*model.ChatRoom) - nickNameToChatRoom := make(map[string]*model.ChatRoom) + remarkToChatRoom := make(map[string][]*model.ChatRoom) + nickNameToChatRoom := make(map[string][]*model.ChatRoom) chatRoomList := make([]string, 0) chatRoomRemark := make([]string, 0) chatRoomNickName := make([]string, 0) @@ -30,11 +30,21 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error { chatRoomMap[chatRoom.Name] = chatRoom chatRoomList = append(chatRoomList, chatRoom.Name) if chatRoom.Remark != "" { - remarkToChatRoom[chatRoom.Remark] = chatRoom + remark, ok := remarkToChatRoom[chatRoom.Remark] + if !ok { + remark = make([]*model.ChatRoom, 0) + } + remark = append(remark, chatRoom) + remarkToChatRoom[chatRoom.Remark] = remark chatRoomRemark = append(chatRoomRemark, chatRoom.Remark) } if chatRoom.NickName != "" { - nickNameToChatRoom[chatRoom.NickName] = chatRoom + nickName, ok := nickNameToChatRoom[chatRoom.NickName] + if !ok { + nickName = make([]*model.ChatRoom, 0) + } + nickName = append(nickName, chatRoom) + nickNameToChatRoom[chatRoom.NickName] = nickName chatRoomNickName = append(chatRoomNickName, chatRoom.NickName) } } @@ -49,11 +59,21 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error { chatRoomMap[contact.UserName] = chatRoom chatRoomList = append(chatRoomList, contact.UserName) if contact.Remark != "" { - remarkToChatRoom[contact.Remark] = chatRoom + remark, ok := remarkToChatRoom[chatRoom.Remark] + if !ok { + remark = make([]*model.ChatRoom, 0) + } + remark = append(remark, chatRoom) + remarkToChatRoom[chatRoom.Remark] = remark chatRoomRemark = append(chatRoomRemark, contact.Remark) } if contact.NickName != "" { - nickNameToChatRoom[contact.NickName] = chatRoom + nickName, ok := nickNameToChatRoom[chatRoom.NickName] + if !ok { + nickName = make([]*model.ChatRoom, 0) + } + nickName = append(nickName, chatRoom) + nickNameToChatRoom[chatRoom.NickName] = nickName chatRoomNickName = append(chatRoomNickName, contact.NickName) } } @@ -63,9 +83,12 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error { sort.Strings(chatRoomNickName) r.chatRoomCache = chatRoomMap - r.chatRoomList = chatRoomList r.remarkToChatRoom = remarkToChatRoom r.nickNameToChatRoom = nickNameToChatRoom + r.chatRoomList = chatRoomList + r.chatRoomRemark = chatRoomRemark + r.chatRoomNickName = chatRoomNickName + return nil } @@ -75,7 +98,7 @@ func (r *Repository) GetChatRooms(ctx context.Context, key string, limit, offset if key != "" { ret = r.findChatRooms(key) if len(ret) == 0 { - return nil, errors.ChatRoomNotFound(key) + return []*model.ChatRoom{}, nil } if limit > 0 { @@ -129,21 +152,21 @@ func (r *Repository) findChatRoom(key string) *model.ChatRoom { return chatRoom } if chatRoom, ok := r.remarkToChatRoom[key]; ok { - return chatRoom + return chatRoom[0] } if chatRoom, ok := r.nickNameToChatRoom[key]; ok { - return chatRoom + return chatRoom[0] } // Contain for _, remark := range r.chatRoomRemark { if strings.Contains(remark, key) { - return r.remarkToChatRoom[remark] + return r.remarkToChatRoom[remark][0] } } for _, nickName := range r.chatRoomNickName { if strings.Contains(nickName, key) { - return r.nickNameToChatRoom[nickName] + return r.nickNameToChatRoom[nickName][0] } } @@ -157,26 +180,42 @@ func (r *Repository) findChatRooms(key string) []*model.ChatRoom { ret = append(ret, chatRoom) distinct[chatRoom.Name] = true } - if chatRoom, ok := r.remarkToChatRoom[key]; ok && !distinct[chatRoom.Name] { - ret = append(ret, chatRoom) - distinct[chatRoom.Name] = true + if chatRooms, ok := r.remarkToChatRoom[key]; ok { + for _, chatRoom := range chatRooms { + if !distinct[chatRoom.Name] { + ret = append(ret, chatRoom) + distinct[chatRoom.Name] = true + } + } } - if chatRoom, ok := r.nickNameToChatRoom[key]; ok && !distinct[chatRoom.Name] { - ret = append(ret, chatRoom) - distinct[chatRoom.Name] = true + if chatRooms, ok := r.nickNameToChatRoom[key]; ok { + for _, chatRoom := range chatRooms { + if !distinct[chatRoom.Name] { + ret = append(ret, chatRoom) + distinct[chatRoom.Name] = true + } + } } // Contain for _, remark := range r.chatRoomRemark { - if strings.Contains(remark, key) && !distinct[r.remarkToChatRoom[remark].Name] { - ret = append(ret, r.remarkToChatRoom[remark]) - distinct[r.remarkToChatRoom[remark].Name] = true + if strings.Contains(remark, key) { + for _, chatRoom := range r.remarkToChatRoom[remark] { + if !distinct[chatRoom.Name] { + ret = append(ret, chatRoom) + distinct[chatRoom.Name] = true + } + } } } for _, nickName := range r.chatRoomNickName { - if strings.Contains(nickName, key) && !distinct[r.nickNameToChatRoom[nickName].Name] { - ret = append(ret, r.nickNameToChatRoom[nickName]) - distinct[r.nickNameToChatRoom[nickName].Name] = true + if strings.Contains(nickName, key) { + for _, chatRoom := range r.nickNameToChatRoom[nickName] { + if !distinct[chatRoom.Name] { + ret = append(ret, chatRoom) + distinct[chatRoom.Name] = true + } + } } } diff --git a/internal/wechatdb/repository/contact.go b/internal/wechatdb/repository/contact.go index 5d641fa..63902fc 100644 --- a/internal/wechatdb/repository/contact.go +++ b/internal/wechatdb/repository/contact.go @@ -18,9 +18,9 @@ func (r *Repository) initContactCache(ctx context.Context) error { } contactMap := make(map[string]*model.Contact) - aliasMap := make(map[string]*model.Contact) - remarkMap := make(map[string]*model.Contact) - nickNameMap := make(map[string]*model.Contact) + aliasMap := make(map[string][]*model.Contact) + remarkMap := make(map[string][]*model.Contact) + nickNameMap := make(map[string][]*model.Contact) chatRoomUserMap := make(map[string]*model.Contact) chatRoomInContactMap := make(map[string]*model.Contact) contactList := make([]string, 0) @@ -34,15 +34,30 @@ func (r *Repository) initContactCache(ctx context.Context) error { // 建立快速查找索引 if contact.Alias != "" { - aliasMap[contact.Alias] = contact + alias, ok := aliasMap[contact.Alias] + if !ok { + alias = make([]*model.Contact, 0) + } + alias = append(alias, contact) + aliasMap[contact.Alias] = alias aliasList = append(aliasList, contact.Alias) } if contact.Remark != "" { - remarkMap[contact.Remark] = contact + remark, ok := remarkMap[contact.Remark] + if !ok { + remark = make([]*model.Contact, 0) + } + remark = append(remark, contact) + remarkMap[contact.Remark] = remark remarkList = append(remarkList, contact.Remark) } if contact.NickName != "" { - nickNameMap[contact.NickName] = contact + nickName, ok := nickNameMap[contact.NickName] + if !ok { + nickName = make([]*model.Contact, 0) + } + nickName = append(nickName, contact) + nickNameMap[contact.NickName] = nickName nickNameList = append(nickNameList, contact.NickName) } @@ -88,7 +103,7 @@ func (r *Repository) GetContacts(ctx context.Context, key string, limit, offset if key != "" { ret = r.findContacts(key) if len(ret) == 0 { - return nil, errors.ContactNotFound(key) + return []*model.Contact{}, nil } if limit > 0 { end := offset + limit @@ -124,29 +139,29 @@ func (r *Repository) findContact(key string) *model.Contact { return contact } if contact, ok := r.aliasToContact[key]; ok { - return contact + return contact[0] } if contact, ok := r.remarkToContact[key]; ok { - return contact + return contact[0] } if contact, ok := r.nickNameToContact[key]; ok { - return contact + return contact[0] } // Contain for _, alias := range r.aliasList { if strings.Contains(alias, key) { - return r.aliasToContact[alias] + return r.aliasToContact[alias][0] } } for _, remark := range r.remarkList { if strings.Contains(remark, key) { - return r.remarkToContact[remark] + return r.remarkToContact[remark][0] } } for _, nickName := range r.nickNameList { if strings.Contains(nickName, key) { - return r.nickNameToContact[nickName] + return r.nickNameToContact[nickName][0] } } return nil @@ -159,37 +174,62 @@ func (r *Repository) findContacts(key string) []*model.Contact { ret = append(ret, contact) distinct[contact.UserName] = true } - if contact, ok := r.aliasToContact[key]; ok && !distinct[contact.UserName] { - ret = append(ret, contact) - distinct[contact.UserName] = true + if contacts, ok := r.aliasToContact[key]; ok { + for _, contact := range contacts { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } - if contact, ok := r.remarkToContact[key]; ok && !distinct[contact.UserName] { - ret = append(ret, contact) - distinct[contact.UserName] = true + if contacts, ok := r.remarkToContact[key]; ok { + for _, contact := range contacts { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } - if contact, ok := r.nickNameToContact[key]; ok && !distinct[contact.UserName] { - ret = append(ret, contact) - distinct[contact.UserName] = true + if contacts, ok := r.nickNameToContact[key]; ok { + for _, contact := range contacts { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } // Contain for _, alias := range r.aliasList { - if strings.Contains(alias, key) && !distinct[r.aliasToContact[alias].UserName] { - ret = append(ret, r.aliasToContact[alias]) - distinct[r.aliasToContact[alias].UserName] = true + if strings.Contains(alias, key) { + for _, contact := range r.aliasToContact[alias] { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } } for _, remark := range r.remarkList { - if strings.Contains(remark, key) && !distinct[r.remarkToContact[remark].UserName] { - ret = append(ret, r.remarkToContact[remark]) - distinct[r.remarkToContact[remark].UserName] = true + if strings.Contains(remark, key) { + for _, contact := range r.remarkToContact[remark] { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } } for _, nickName := range r.nickNameList { - if strings.Contains(nickName, key) && !distinct[r.nickNameToContact[nickName].UserName] { - ret = append(ret, r.nickNameToContact[nickName]) - distinct[r.nickNameToContact[nickName].UserName] = true + if strings.Contains(nickName, key) { + for _, contact := range r.nickNameToContact[nickName] { + if !distinct[contact.UserName] { + ret = append(ret, contact) + distinct[contact.UserName] = true + } + } } } + return ret } diff --git a/internal/wechatdb/repository/message.go b/internal/wechatdb/repository/message.go index 199142d..257b07f 100644 --- a/internal/wechatdb/repository/message.go +++ b/internal/wechatdb/repository/message.go @@ -2,23 +2,20 @@ package repository import ( "context" + "strings" "time" "github.com/sjzar/chatlog/internal/model" + "github.com/sjzar/chatlog/pkg/util" "github.com/rs/zerolog/log" ) // GetMessages 实现 Repository 接口的 GetMessages 方法 -func (r *Repository) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { +func (r *Repository) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { - if contact, _ := r.GetContact(ctx, talker); contact != nil { - talker = contact.UserName - } else if chatRoom, _ := r.GetChatRoom(ctx, talker); chatRoom != nil { - talker = chatRoom.Name - } - - messages, err := r.ds.GetMessages(ctx, startTime, endTime, talker, limit, offset) + talker, sender = r.parseTalkerAndSender(ctx, talker, sender) + messages, err := r.ds.GetMessages(ctx, startTime, endTime, talker, sender, keyword, limit, offset) if err != nil { return nil, err } @@ -62,3 +59,53 @@ func (r *Repository) enrichMessage(msg *model.Message) { } } } + +func (r *Repository) parseTalkerAndSender(ctx context.Context, talker, sender string) (string, string) { + displayName2User := make(map[string]string) + users := make(map[string]bool) + + talkers := util.Str2List(talker, ",") + if len(talkers) > 0 { + for i := 0; i < len(talkers); i++ { + if contact, _ := r.GetContact(ctx, talkers[i]); contact != nil { + talkers[i] = contact.UserName + } else if chatRoom, _ := r.GetChatRoom(ctx, talker); chatRoom != nil { + talkers[i] = chatRoom.Name + } + } + // 获取群聊的用户列表 + for i := 0; i < len(talkers); i++ { + if chatRoom, _ := r.GetChatRoom(ctx, talkers[i]); chatRoom != nil { + for user, displayName := range chatRoom.User2DisplayName { + displayName2User[displayName] = user + } + for _, user := range chatRoom.Users { + users[user.UserName] = true + } + } + } + talker = strings.Join(talkers, ",") + } + + senders := util.Str2List(sender, ",") + if len(senders) > 0 { + for i := 0; i < len(senders); i++ { + if user, ok := displayName2User[senders[i]]; ok { + senders[i] = user + } else { + // FIXME 大量群聊用户名称重复,无法直接通过 GetContact 获取 ID,后续再优化 + for user := range users { + if contact := r.getFullContact(user); contact != nil { + if contact.DisplayName() == senders[i] { + senders[i] = user + break + } + } + } + } + } + sender = strings.Join(senders, ",") + } + + return talker, sender +} diff --git a/internal/wechatdb/repository/repository.go b/internal/wechatdb/repository/repository.go index 5cdcbca..bf26803 100644 --- a/internal/wechatdb/repository/repository.go +++ b/internal/wechatdb/repository/repository.go @@ -17,9 +17,9 @@ type Repository struct { // Cache for contact contactCache map[string]*model.Contact - aliasToContact map[string]*model.Contact - remarkToContact map[string]*model.Contact - nickNameToContact map[string]*model.Contact + aliasToContact map[string][]*model.Contact + remarkToContact map[string][]*model.Contact + nickNameToContact map[string][]*model.Contact chatRoomInContact map[string]*model.Contact contactList []string aliasList []string @@ -28,8 +28,8 @@ type Repository struct { // Cache for chat room chatRoomCache map[string]*model.ChatRoom - remarkToChatRoom map[string]*model.ChatRoom - nickNameToChatRoom map[string]*model.ChatRoom + remarkToChatRoom map[string][]*model.ChatRoom + nickNameToChatRoom map[string][]*model.ChatRoom chatRoomList []string chatRoomRemark []string chatRoomNickName []string @@ -43,17 +43,17 @@ func New(ds datasource.DataSource) (*Repository, error) { r := &Repository{ ds: ds, contactCache: make(map[string]*model.Contact), - aliasToContact: make(map[string]*model.Contact), - remarkToContact: make(map[string]*model.Contact), - nickNameToContact: make(map[string]*model.Contact), + aliasToContact: make(map[string][]*model.Contact), + remarkToContact: make(map[string][]*model.Contact), + nickNameToContact: make(map[string][]*model.Contact), chatRoomUserToInfo: make(map[string]*model.Contact), contactList: make([]string, 0), aliasList: make([]string, 0), remarkList: make([]string, 0), nickNameList: make([]string, 0), chatRoomCache: make(map[string]*model.ChatRoom), - remarkToChatRoom: make(map[string]*model.ChatRoom), - nickNameToChatRoom: make(map[string]*model.ChatRoom), + remarkToChatRoom: make(map[string][]*model.ChatRoom), + nickNameToChatRoom: make(map[string][]*model.ChatRoom), chatRoomList: make([]string, 0), chatRoomRemark: make([]string, 0), chatRoomNickName: make([]string, 0), diff --git a/internal/wechatdb/wechatdb.go b/internal/wechatdb/wechatdb.go index 3c80cb7..2e90ddf 100644 --- a/internal/wechatdb/wechatdb.go +++ b/internal/wechatdb/wechatdb.go @@ -57,11 +57,11 @@ func (w *DB) Initialize() error { return nil } -func (w *DB) GetMessages(start, end time.Time, talker string, limit, offset int) ([]*model.Message, error) { +func (w *DB) GetMessages(start, end time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) { ctx := context.Background() // 使用 repository 获取消息 - messages, err := w.repo.GetMessages(ctx, start, end, talker, limit, offset) + messages, err := w.repo.GetMessages(ctx, start, end, talker, sender, keyword, limit, offset) if err != nil { return nil, err } diff --git a/pkg/util/strings.go b/pkg/util/strings.go index 1aaa9cb..804d073 100644 --- a/pkg/util/strings.go +++ b/pkg/util/strings.go @@ -3,6 +3,7 @@ package util import ( "fmt" "strconv" + "strings" "unicode" "unicode/utf8" ) @@ -45,3 +46,26 @@ func IsNumeric(s string) bool { func SplitInt64ToTwoInt32(input int64) (int64, int64) { return input & 0xFFFFFFFF, input >> 32 } + +func Str2List(str string, sep string) []string { + list := make([]string, 0) + + if str == "" { + return list + } + + listMap := make(map[string]bool) + for _, elem := range strings.Split(str, sep) { + elem = strings.TrimSpace(elem) + if len(elem) == 0 { + continue + } + if _, ok := listMap[elem]; ok { + continue + } + listMap[elem] = true + list = append(list, elem) + } + + return list +} diff --git a/pkg/util/time.go b/pkg/util/time.go index cc63085..953c6d9 100644 --- a/pkg/util/time.go +++ b/pkg/util/time.go @@ -582,8 +582,8 @@ func adjustStartTime(t time.Time, g TimeGranularity) time.Time { func adjustEndTime(t time.Time, g TimeGranularity) time.Time { switch g { case GranularitySecond, GranularityMinute, GranularityHour: - // 对于精确到秒/分钟/小时的时间,设置为当天结束 - return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location()) + // 对于精确到秒/分钟/小时的时间,保持原样 + return t case GranularityDay: // 精确到天,设置为当天结束 return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location()) @@ -634,3 +634,25 @@ func isValidDate(year, month, day int) bool { return day <= daysInMonth } + +func PerfectTimeFormat(start time.Time, end time.Time) string { + endTime := end + + // 如果结束时间是某一天的 0 点整,将其减去 1 秒,视为前一天的结束 + if endTime.Hour() == 0 && endTime.Minute() == 0 && endTime.Second() == 0 && endTime.Nanosecond() == 0 { + endTime = endTime.Add(-time.Second) // 减去 1 秒 + } + + // 判断是否跨年 + if start.Year() != endTime.Year() { + return "2006-01-02 15:04:05" // 完整格式,包含年月日时分秒 + } + + // 判断是否跨天(但在同一年内) + if start.YearDay() != endTime.YearDay() { + return "01-02 15:04:05" // 月日时分秒格式 + } + + // 在同一天内 + return "15:04:05" // 只显示时分秒 +}