3 Commits

Author SHA1 Message Date
Sarv
d124086e70 support thumb url (#62) 2025-04-19 18:30:41 +08:00
Sarv
a745519451 message search (#60) 2025-04-19 03:06:05 +08:00
Sarv
85b5465d2a fix db file dependencies (#56) 2025-04-18 00:31:19 +08:00
26 changed files with 1057 additions and 561 deletions

View File

@@ -158,8 +158,20 @@ GET /api/v1/chatlog?time=2023-01-01&talker=wxid_xxx
- **联系人列表**`GET /api/v1/contact` - **联系人列表**`GET /api/v1/contact`
- **群聊列表**`GET /api/v1/chatroom` - **群聊列表**`GET /api/v1/chatroom`
- **会话列表**`GET /api/v1/session` - **会话列表**`GET /api/v1/session`
- **多媒体内容**`GET /api/v1/media?msgid=xxx`
### 多媒体内容
聊天记录中的多媒体内容会通过 HTTP 服务进行提供,可通过以下路径访问:
- **图片内容**`GET /image/<id>`
- **视频内容**`GET /video/<id>`
- **文件内容**`GET /file/<id>`
- **语音内容**`GET /voice/<id>`
- **多媒体内容**`GET /data/<data dir relative path>`
当请求图片、视频、文件内容时,将返回 302 跳转到多媒体内容 URL。
当请求语音内容时,将直接返回语音内容,并对原始 SILK 语音做了实时转码 MP3 处理。
多媒体内容 URL 地址为基于`数据目录`的相对地址,请求多媒体内容将直接返回对应文件,并针对加密图片做了实时解密处理。
## MCP 集成 ## MCP 集成

View File

@@ -40,8 +40,8 @@ func (s *Service) GetDB() *wechatdb.DB {
return s.db return s.db
} }
func (s *Service) GetMessages(start, end time.Time, talker string, limit, offset int) ([]*model.Message, error) { func (s *Service) GetMessages(start, end time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
return s.db.GetMessages(start, end, talker, limit, offset) return s.db.GetMessages(start, end, talker, sender, keyword, limit, offset)
} }
func (s *Service) GetContacts(key string, limit, offset int) (*wechatdb.GetContactsResp, error) { func (s *Service) GetContacts(key string, limit, offset int) (*wechatdb.GetContactsResp, error) {

View File

@@ -32,10 +32,10 @@ func (s *Service) initRouter() {
router.StaticFileFS("/", "./index.htm", http.FS(staticDir)) router.StaticFileFS("/", "./index.htm", http.FS(staticDir))
// Media // Media
router.GET("/image/:key", s.GetImage) router.GET("/image/*key", s.GetImage)
router.GET("/video/:key", s.GetVideo) router.GET("/video/*key", s.GetVideo)
router.GET("/file/:key", s.GetFile) router.GET("/file/*key", s.GetFile)
router.GET("/voice/:key", s.GetVoice) router.GET("/voice/*key", s.GetVoice)
router.GET("/data/*path", s.GetMediaData) router.GET("/data/*path", s.GetMediaData)
// MCP Server // MCP Server
@@ -75,11 +75,13 @@ func (s *Service) NoRoute(c *gin.Context) {
func (s *Service) GetChatlog(c *gin.Context) { func (s *Service) GetChatlog(c *gin.Context) {
q := struct { q := struct {
Time string `form:"time"` Time string `form:"time"`
Talker string `form:"talker"` Talker string `form:"talker"`
Limit int `form:"limit"` Sender string `form:"sender"`
Offset int `form:"offset"` Keyword string `form:"keyword"`
Format string `form:"format"` Limit int `form:"limit"`
Offset int `form:"offset"`
Format string `form:"format"`
}{} }{}
if err := c.BindQuery(&q); err != nil { if err := c.BindQuery(&q); err != nil {
@@ -100,7 +102,7 @@ func (s *Service) GetChatlog(c *gin.Context) {
q.Offset = 0 q.Offset = 0
} }
messages, err := s.db.GetMessages(start, end, q.Talker, q.Limit, q.Offset) messages, err := s.db.GetMessages(start, end, q.Talker, q.Sender, q.Keyword, q.Limit, q.Offset)
if err != nil { if err != nil {
errors.Err(c, err) errors.Err(c, err)
return return
@@ -119,7 +121,7 @@ func (s *Service) GetChatlog(c *gin.Context) {
c.Writer.Flush() c.Writer.Flush()
for _, m := range messages { for _, m := range messages {
c.Writer.WriteString(m.PlainText(len(q.Talker) == 0, c.Request.Host)) c.Writer.WriteString(m.PlainText(strings.Contains(q.Talker, ","), util.PerfectTimeFormat(start, end), c.Request.Host))
c.Writer.WriteString("\n") c.Writer.WriteString("\n")
c.Writer.Flush() c.Writer.Flush()
} }
@@ -129,10 +131,10 @@ func (s *Service) GetChatlog(c *gin.Context) {
func (s *Service) GetContacts(c *gin.Context) { func (s *Service) GetContacts(c *gin.Context) {
q := struct { q := struct {
Key string `form:"key"` Keyword string `form:"keyword"`
Limit int `form:"limit"` Limit int `form:"limit"`
Offset int `form:"offset"` Offset int `form:"offset"`
Format string `form:"format"` Format string `form:"format"`
}{} }{}
if err := c.BindQuery(&q); err != nil { if err := c.BindQuery(&q); err != nil {
@@ -140,7 +142,7 @@ func (s *Service) GetContacts(c *gin.Context) {
return return
} }
list, err := s.db.GetContacts(q.Key, q.Limit, q.Offset) list, err := s.db.GetContacts(q.Keyword, q.Limit, q.Offset)
if err != nil { if err != nil {
errors.Err(c, err) errors.Err(c, err)
return return
@@ -174,10 +176,10 @@ func (s *Service) GetContacts(c *gin.Context) {
func (s *Service) GetChatRooms(c *gin.Context) { func (s *Service) GetChatRooms(c *gin.Context) {
q := struct { q := struct {
Key string `form:"key"` Keyword string `form:"keyword"`
Limit int `form:"limit"` Limit int `form:"limit"`
Offset int `form:"offset"` Offset int `form:"offset"`
Format string `form:"format"` Format string `form:"format"`
}{} }{}
if err := c.BindQuery(&q); err != nil { if err := c.BindQuery(&q); err != nil {
@@ -185,7 +187,7 @@ func (s *Service) GetChatRooms(c *gin.Context) {
return return
} }
list, err := s.db.GetChatRooms(q.Key, q.Limit, q.Offset) list, err := s.db.GetChatRooms(q.Keyword, q.Limit, q.Offset)
if err != nil { if err != nil {
errors.Err(c, err) errors.Err(c, err)
return return
@@ -218,10 +220,10 @@ func (s *Service) GetChatRooms(c *gin.Context) {
func (s *Service) GetSessions(c *gin.Context) { func (s *Service) GetSessions(c *gin.Context) {
q := struct { q := struct {
Key string `form:"key"` Keyword string `form:"keyword"`
Limit int `form:"limit"` Limit int `form:"limit"`
Offset int `form:"offset"` Offset int `form:"offset"`
Format string `form:"format"` Format string `form:"format"`
}{} }{}
if err := c.BindQuery(&q); err != nil { if err := c.BindQuery(&q); err != nil {
@@ -229,7 +231,7 @@ func (s *Service) GetSessions(c *gin.Context) {
return return
} }
sessions, err := s.db.GetSessions(q.Key, q.Limit, q.Offset) sessions, err := s.db.GetSessions(q.Keyword, q.Limit, q.Offset)
if err != nil { if err != nil {
errors.Err(c, err) errors.Err(c, err)
return return
@@ -279,30 +281,51 @@ func (s *Service) GetVoice(c *gin.Context) {
} }
func (s *Service) GetMedia(c *gin.Context, _type string) { func (s *Service) GetMedia(c *gin.Context, _type string) {
key := c.Param("key") key := strings.TrimPrefix(c.Param("key"), "/")
if key == "" { if key == "" {
errors.Err(c, errors.InvalidArg(key)) errors.Err(c, errors.InvalidArg(key))
return return
} }
media, err := s.db.GetMedia(_type, key) keys := util.Str2List(key, ",")
if err != nil { if len(keys) == 0 {
errors.Err(c, err) errors.Err(c, errors.InvalidArg(key))
return return
} }
if c.Query("info") != "" { var _err error
c.JSON(http.StatusOK, media) for _, k := range keys {
if len(k) != 32 {
absolutePath := filepath.Join(s.ctx.DataDir, k)
if _, err := os.Stat(absolutePath); os.IsNotExist(err) {
continue
}
c.Redirect(http.StatusFound, "/data/"+k)
return
}
media, err := s.db.GetMedia(_type, k)
if err != nil {
_err = err
continue
}
if c.Query("info") != "" {
c.JSON(http.StatusOK, media)
return
}
switch media.Type {
case "voice":
s.HandleVoice(c, media.Data)
return
default:
c.Redirect(http.StatusFound, "/data/"+media.Path)
return
}
}
if _err != nil {
errors.Err(c, _err)
return return
} }
switch media.Type {
case "voice":
s.HandleVoice(c, media.Data)
default:
c.Redirect(http.StatusFound, "/data/"+media.Path)
}
} }
func (s *Service) GetMediaData(c *gin.Context) { func (s *Service) GetMediaData(c *gin.Context) {
@@ -351,7 +374,8 @@ func (s *Service) HandleDatFile(c *gin.Context, path string) {
case "bmp": case "bmp":
c.Data(http.StatusOK, "image/bmp", out) c.Data(http.StatusOK, "image/bmp", out)
default: default:
c.File(path) c.Data(http.StatusOK, "image/jpg", out)
// c.File(path)
} }
} }

View File

@@ -377,12 +377,12 @@
</p> </p>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="chatroom-query" <label for="chatroom-keyword"
>搜索群聊:<span class="optional-param">可选</span></label >搜索群聊:<span class="optional-param">可选</span></label
> >
<input <input
type="text" type="text"
id="chatroom-query" id="chatroom-keyword"
placeholder="输入关键词搜索群聊" placeholder="输入关键词搜索群聊"
/> />
</div> </div>
@@ -408,12 +408,12 @@
</p> </p>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="contact-query" <label for="contact-keyword"
>搜索联系人:<span class="optional-param">可选</span></label >搜索联系人:<span class="optional-param">可选</span></label
> >
<input <input
type="text" type="text"
id="contact-query" id="contact-keyword"
placeholder="输入关键词搜索联系人" placeholder="输入关键词搜索联系人"
/> />
</div> </div>
@@ -459,6 +459,26 @@
placeholder="wxid、群ID、备注名或昵称" placeholder="wxid、群ID、备注名或昵称"
/> />
</div> </div>
<div class="form-group">
<label for="sender"
>发送者:<span class="optional-param">可选</span></label
>
<input
type="text"
id="sender"
placeholder="指定消息发送者"
/>
</div>
<div class="form-group">
<label for="keyword"
>关键词:<span class="optional-param">可选</span></label
>
<input
type="text"
id="keyword"
placeholder="搜索消息内容中的关键词"
/>
</div>
<div class="form-group"> <div class="form-group">
<label for="limit" <label for="limit"
>返回数量:<span class="optional-param">可选</span></label >返回数量:<span class="optional-param">可选</span></label
@@ -603,6 +623,8 @@
url += "chatlog"; url += "chatlog";
const time = document.getElementById("time").value; const time = document.getElementById("time").value;
const talker = document.getElementById("talker").value; const talker = document.getElementById("talker").value;
const sender = document.getElementById("sender").value;
const keyword = document.getElementById("keyword").value;
const limit = document.getElementById("limit").value; const limit = document.getElementById("limit").value;
const offset = document.getElementById("offset").value; const offset = document.getElementById("offset").value;
const format = document.getElementById("format").value; const format = document.getElementById("format").value;
@@ -617,6 +639,8 @@
if (time) params.append("time", time); if (time) params.append("time", time);
if (talker) params.append("talker", talker); if (talker) params.append("talker", talker);
if (sender) params.append("sender", sender);
if (keyword) params.append("keyword", keyword);
if (limit) params.append("limit", limit); if (limit) params.append("limit", limit);
if (offset) params.append("offset", offset); if (offset) params.append("offset", offset);
if (format) params.append("format", format); if (format) params.append("format", format);
@@ -624,23 +648,23 @@
case "contact": case "contact":
url += "contact"; url += "contact";
const contactQuery = const contactKeyword =
document.getElementById("contact-query").value; document.getElementById("contact-keyword").value;
const contactFormat = const contactFormat =
document.getElementById("contact-format").value; document.getElementById("contact-format").value;
if (contactQuery) params.append("query", contactQuery); if (contactKeyword) params.append("keyword", contactKeyword);
if (contactFormat) params.append("format", contactFormat); if (contactFormat) params.append("format", contactFormat);
break; break;
case "chatroom": case "chatroom":
url += "chatroom"; url += "chatroom";
const chatroomQuery = const chatroomKeyword =
document.getElementById("chatroom-query").value; document.getElementById("chatroom-keyword").value;
const chatroomFormat = const chatroomFormat =
document.getElementById("chatroom-format").value; document.getElementById("chatroom-format").value;
if (chatroomQuery) params.append("query", chatroomQuery); if (chatroomKeyword) params.append("keyword", chatroomKeyword);
if (chatroomFormat) params.append("format", chatroomFormat); if (chatroomFormat) params.append("format", chatroomFormat);
break; break;

View File

@@ -21,12 +21,12 @@ var (
InputSchema: mcp.ToolSchema{ InputSchema: mcp.ToolSchema{
Type: "object", Type: "object",
Properties: mcp.M{ Properties: mcp.M{
"query": mcp.M{ "keyword": mcp.M{
"type": "string", "type": "string",
"description": "联系人的搜索关键词可以是姓名、备注名或ID。", "description": "联系人的搜索关键词可以是姓名、备注名或ID。",
}, },
}, },
Required: []string{"query"}, Required: []string{"keyword"},
}, },
} }
@@ -36,12 +36,12 @@ var (
InputSchema: mcp.ToolSchema{ InputSchema: mcp.ToolSchema{
Type: "object", Type: "object",
Properties: mcp.M{ Properties: mcp.M{
"query": mcp.M{ "keyword": mcp.M{
"type": "string", "type": "string",
"description": "群聊的搜索关键词可以是群名称、群ID或相关描述", "description": "群聊的搜索关键词可以是群名称、群ID或相关描述",
}, },
}, },
Required: []string{"query"}, Required: []string{"keyword"},
}, },
} }
@@ -55,24 +55,136 @@ var (
} }
ToolChatLog = mcp.Tool{ ToolChatLog = mcp.Tool{
Name: "chatlog", Name: "chatlog",
Description: "查询特定时间或时间段内与特定联系人或群组的聊天记录。当用户需要回顾过去的对话内容、查找特定信息或想了解与某人/某群的历史交流时使用此工具。", Description: `检索历史聊天记录,可根据时间、对话方、发送者和关键词等条件进行精确查询。当用户需要查找特定信息或想了解与某人/某群的历史交流时使用此工具。
【强制多步查询流程!】
当查询特定话题或特定发送者发言时,必须严格按照以下流程使用,任何偏离都会导致错误的结果:
步骤1: 初步定位相关消息
- 使用keyword参数查找特定话题
- 使用sender参数查找特定发送者的消息
- 使用较宽时间范围初步查询
步骤2: 【必须执行】针对每个关键结果点分别获取上下文
- 必须对步骤1返回的每个时间点T1, T2, T3...分别执行独立查询(时间范围接近的消息可以合并为一个查询)
- 每次独立查询必须移除keyword参数
- 每次独立查询必须移除sender参数
- 每次独立查询使用"Tn前后15-30分钟"的窄范围
- 每次独立查询仅保留talker参数
步骤3: 【必须执行】综合分析所有上下文
- 必须等待所有步骤2的查询结果返回后再进行分析
- 必须综合考虑所有上下文信息后再回答用户
【严格执行规则!】
- 禁止仅凭步骤1的结果直接回答用户
- 禁止在步骤2使用过大的时间范围一次性查询所有上下文
- 禁止跳过步骤2或步骤3
- 必须对每个关键结果点分别执行独立的上下文查询
【执行示例】
正确流程示例:
1. 步骤1: chatlog(time="2023-04-01~2023-04-30", talker="工作群", keyword="项目进度")
返回结果: 4月5日、4月12日、4月20日有相关消息
2. 步骤2:
- 查询1: chatlog(time="2023-04-05/09:30~2023-04-05/10:30", talker="工作群") // 注意没有keyword
- 查询2: chatlog(time="2023-04-12/14:00~2023-04-12/15:00", talker="工作群") // 注意没有keyword
- 查询3: chatlog(time="2023-04-20/16:00~2023-04-20/17:00", talker="工作群") // 注意没有keyword
3. 步骤3: 综合分析所有上下文后回答用户
错误流程示例:
- 仅执行步骤1后直接回答
- 步骤2使用time="2023-04-01~2023-04-30"一次性查询
- 步骤2仍然保留keyword或sender参数
【自我检查】回答用户前必须自问:
- 我是否对每个关键时间点都执行了独立的上下文查询?
- 我是否在上下文查询中移除了keyword和sender参数?
- 我是否分析了所有上下文后再回答?
- 如果上述任一问题答案为"否",则必须纠正流程
返回格式:"昵称(ID) 时间\n消息内容\n昵称(ID) 时间\n消息内容"
当查询多个Talker时返回格式为"昵称(ID)\n[TalkerName(Talker)] 时间\n消息内容"
重要提示:
1. 当用户询问特定时间段内的聊天记录时,必须使用正确的时间格式,特别是包含小时和分钟的查询
2. 对于"今天下午4点到5点聊了啥"这类查询,正确的时间参数格式应为"2023-04-18/16:00~2023-04-18/17:00"
3. 当用户询问具体群聊中某人的聊天记录时,使用"sender"参数
4. 当用户询问包含特定关键词的聊天记录时,使用"keyword"参数`,
InputSchema: mcp.ToolSchema{ InputSchema: mcp.ToolSchema{
Type: "object", Type: "object",
Properties: mcp.M{ Properties: mcp.M{
"time": mcp.M{ "time": mcp.M{
"type": "string", "type": "string",
"description": "查询的时间点或时间段。可以是具体时间,例如 YYYY-MM-DD也可以是时间段例如 YYYY-MM-DD~YYYY-MM-DD时间段之间用\"~\"分隔。", "description": `指定查询的时间点或时间范围,格式必须严格遵循以下规则:
【单一时间点格式】
- 精确到日:"2023-04-18"或"20230418"
- 精确到分钟(必须包含斜杠和冒号):"2023-04-18/14:30"或"20230418/14:30"表示2023年4月18日14点30分
【时间范围格式】(使用"~"分隔起止时间)
- 日期范围:"2023-04-01~2023-04-18"
- 同一天的时间段:"2023-04-18/14:30~2023-04-18/15:45"
* 表示2023年4月18日14点30分到15点45分之间
【重要提示】包含小时分钟的格式必须使用斜杠和冒号:"/"和":"
正确示例:"2023-04-18/16:30"4月18日下午4点30分
错误示例:"2023-04-18 16:30"、"2023-04-18T16:30"
【其他支持的格式】
- 年份:"2023"
- 月份:"2023-04"或"202304"`,
}, },
"talker": mcp.M{ "talker": mcp.M{
"type": "string", "type": "string",
"description": "交谈对象可以是联系人或群聊。支持使用ID、昵称、备注名等进行查询。", "description": `指定对话方(联系人或群组)
- 可使用ID、昵称或备注名
- 多个对话方用","分隔,如:"张三,李四,工作群"
- 【重要】这是多步查询中唯一应保留的参数`,
},
"sender": mcp.M{
"type": "string",
"description": `指定群聊中的发送者
- 仅在查询群聊记录时有效
- 多个发送者用","分隔,如:"张三,李四"
- 可使用ID、昵称或备注名
【重要】查询特定发送者的消息时:
1. 第一步使用sender参数初步定位多个相关消息时间点
2. 后续步骤必须移除sender参数分别查询每个时间点前后的完整对话
3. 错误示例:对所有找到的消息一次性查询大范围上下文
4. 正确示例对每个时间点T分别执行查询"T前后15-30分钟"不带sender`,
},
"keyword": mcp.M{
"type": "string",
"description": `搜索内容中的关键词
- 支持正则表达式匹配
- 【重要】查询特定话题时:
1. 第一步使用keyword参数初步定位多个相关消息时间点
2. 后续步骤必须移除keyword参数分别查询每个时间点前后的完整对话
3. 错误示例:对所有找到的关键词消息一次性查询大范围上下文
4. 正确示例对每个时间点T分别执行查询"T前后15-30分钟"不带keyword`,
}, },
}, },
Required: []string{"time", "talker"}, Required: []string{"time", "talker"},
}, },
} }
ToolCurrentTime = mcp.Tool{
Name: "current_time",
Description: `获取当前系统时间返回RFC3339格式的时间字符串包含用户本地时区信息
使用场景:
- 当用户询问"总结今日聊天记录"、"本周都聊了啥"等当前时间问题
- 当用户提及"昨天"、"上周"、"本月"等相对时间概念,需要确定基准时间点
- 需要执行依赖当前时间的计算(如"上个月5号我们有开会吗"
返回示例2025-04-18T21:29:00+08:00
注意:此工具不需要任何输入参数,直接调用即可获取当前时间。`,
InputSchema: mcp.ToolSchema{
Type: "object",
Properties: mcp.M{},
},
}
ResourceRecentChat = mcp.Resource{ ResourceRecentChat = mcp.Resource{
Name: "最近会话", Name: "最近会话",
URI: "session://recent", URI: "session://recent",

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/sjzar/chatlog/internal/chatlog/ctx" "github.com/sjzar/chatlog/internal/chatlog/ctx"
"github.com/sjzar/chatlog/internal/chatlog/database" "github.com/sjzar/chatlog/internal/chatlog/database"
@@ -83,6 +84,7 @@ func (s *Service) processMCP(session *mcp.Session, req *mcp.Request) {
ToolChatRoom, ToolChatRoom,
ToolRecentChat, ToolRecentChat,
ToolChatLog, ToolChatLog,
ToolCurrentTime,
}}) }})
case mcp.MethodToolsCall: case mcp.MethodToolsCall:
err = s.toolsCall(session, req) err = s.toolsCall(session, req)
@@ -130,13 +132,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
switch callReq.Name { switch callReq.Name {
case "query_contact": case "query_contact":
query := "" keyword := ""
if v, ok := callReq.Arguments["query"]; ok { if v, ok := callReq.Arguments["keyword"]; ok {
query = v.(string) keyword = v.(string)
} }
limit := util.MustAnyToInt(callReq.Arguments["limit"]) limit := util.MustAnyToInt(callReq.Arguments["limit"])
offset := util.MustAnyToInt(callReq.Arguments["offset"]) offset := util.MustAnyToInt(callReq.Arguments["offset"])
list, err := s.db.GetContacts(query, limit, offset) list, err := s.db.GetContacts(keyword, limit, offset)
if err != nil { if err != nil {
return fmt.Errorf("无法获取联系人列表: %v", err) return fmt.Errorf("无法获取联系人列表: %v", err)
} }
@@ -145,13 +147,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error {
buf.WriteString(fmt.Sprintf("%s,%s,%s,%s\n", contact.UserName, contact.Alias, contact.Remark, contact.NickName)) buf.WriteString(fmt.Sprintf("%s,%s,%s,%s\n", contact.UserName, contact.Alias, contact.Remark, contact.NickName))
} }
case "query_chat_room": case "query_chat_room":
query := "" keyword := ""
if v, ok := callReq.Arguments["query"]; ok { if v, ok := callReq.Arguments["keyword"]; ok {
query = v.(string) keyword = v.(string)
} }
limit := util.MustAnyToInt(callReq.Arguments["limit"]) limit := util.MustAnyToInt(callReq.Arguments["limit"])
offset := util.MustAnyToInt(callReq.Arguments["offset"]) offset := util.MustAnyToInt(callReq.Arguments["offset"])
list, err := s.db.GetChatRooms(query, limit, offset) list, err := s.db.GetChatRooms(keyword, limit, offset)
if err != nil { if err != nil {
return fmt.Errorf("无法获取群聊列表: %v", err) return fmt.Errorf("无法获取群聊列表: %v", err)
} }
@@ -160,13 +162,13 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error {
buf.WriteString(fmt.Sprintf("%s,%s,%s,%s,%d\n", chatRoom.Name, chatRoom.Remark, chatRoom.NickName, chatRoom.Owner, len(chatRoom.Users))) buf.WriteString(fmt.Sprintf("%s,%s,%s,%s,%d\n", chatRoom.Name, chatRoom.Remark, chatRoom.NickName, chatRoom.Owner, len(chatRoom.Users)))
} }
case "query_recent_chat": case "query_recent_chat":
query := "" keyword := ""
if v, ok := callReq.Arguments["query"]; ok { if v, ok := callReq.Arguments["keyword"]; ok {
query = v.(string) keyword = v.(string)
} }
limit := util.MustAnyToInt(callReq.Arguments["limit"]) limit := util.MustAnyToInt(callReq.Arguments["limit"])
offset := util.MustAnyToInt(callReq.Arguments["offset"]) offset := util.MustAnyToInt(callReq.Arguments["offset"])
data, err := s.db.GetSessions(query, limit, offset) data, err := s.db.GetSessions(keyword, limit, offset)
if err != nil { if err != nil {
return fmt.Errorf("无法获取会话列表: %v", err) return fmt.Errorf("无法获取会话列表: %v", err)
} }
@@ -190,16 +192,29 @@ func (s *Service) toolsCall(session *mcp.Session, req *mcp.Request) error {
if v, ok := callReq.Arguments["talker"]; ok { if v, ok := callReq.Arguments["talker"]; ok {
talker = v.(string) talker = v.(string)
} }
sender := ""
if v, ok := callReq.Arguments["sender"]; ok {
sender = v.(string)
}
keyword := ""
if v, ok := callReq.Arguments["keyword"]; ok {
keyword = v.(string)
}
limit := util.MustAnyToInt(callReq.Arguments["limit"]) limit := util.MustAnyToInt(callReq.Arguments["limit"])
offset := util.MustAnyToInt(callReq.Arguments["offset"]) offset := util.MustAnyToInt(callReq.Arguments["offset"])
messages, err := s.db.GetMessages(start, end, talker, limit, offset) messages, err := s.db.GetMessages(start, end, talker, sender, keyword, limit, offset)
if err != nil { if err != nil {
return fmt.Errorf("无法获取聊天记录: %v", err) return fmt.Errorf("无法获取聊天记录: %v", err)
} }
if len(messages) == 0 {
buf.WriteString("未找到符合查询条件的聊天记录")
}
for _, m := range messages { for _, m := range messages {
buf.WriteString(m.PlainText(len(talker) == 0, "")) buf.WriteString(m.PlainText(strings.Contains(talker, ","), util.PerfectTimeFormat(start, end), ""))
buf.WriteString("\n") buf.WriteString("\n")
} }
case "current_time":
buf.WriteString(time.Now().Local().Format(time.RFC3339))
default: default:
return fmt.Errorf("未支持的工具: %s", callReq.Name) return fmt.Errorf("未支持的工具: %s", callReq.Name)
} }
@@ -228,7 +243,6 @@ func (s *Service) resourcesRead(session *mcp.Session, req *mcp.Request) error {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
switch u.Scheme { switch u.Scheme {
case "contact": case "contact":
list, err := s.db.GetContacts(u.Host, 0, 0) list, err := s.db.GetContacts(u.Host, 0, 0)
if err != nil { if err != nil {
return fmt.Errorf("无法获取联系人列表: %v", err) return fmt.Errorf("无法获取联系人列表: %v", err)
@@ -262,12 +276,15 @@ func (s *Service) resourcesRead(session *mcp.Session, req *mcp.Request) error {
} }
limit := util.MustAnyToInt(u.Query().Get("limit")) limit := util.MustAnyToInt(u.Query().Get("limit"))
offset := util.MustAnyToInt(u.Query().Get("offset")) offset := util.MustAnyToInt(u.Query().Get("offset"))
messages, err := s.db.GetMessages(start, end, u.Host, limit, offset) messages, err := s.db.GetMessages(start, end, u.Host, "", "", limit, offset)
if err != nil { if err != nil {
return fmt.Errorf("无法获取聊天记录: %v", err) return fmt.Errorf("无法获取聊天记录: %v", err)
} }
if len(messages) == 0 {
buf.WriteString("未找到符合查询条件的聊天记录")
}
for _, m := range messages { for _, m := range messages {
buf.WriteString(m.PlainText(len(u.Host) == 0, "")) buf.WriteString(m.PlainText(strings.Contains(u.Host, ","), util.PerfectTimeFormat(start, end), ""))
buf.WriteString("\n") buf.WriteString("\n")
} }
default: default:

View File

@@ -65,7 +65,7 @@ func Newf(cause error, code int, format string, args ...interface{}) *Error {
return &Error{ return &Error{
Message: fmt.Sprintf(format, args...), Message: fmt.Sprintf(format, args...),
Cause: cause, Cause: cause,
Code: http.StatusInternalServerError, Code: code,
} }
} }

View File

@@ -2,6 +2,7 @@ package errors
import ( import (
"net/http" "net/http"
"runtime/debug"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -50,7 +51,7 @@ func RecoveryMiddleware() gin.HandlerFunc {
} }
// 记录错误日志 // 记录错误日志
log.Err(err).Msg("PANIC RECOVERED") log.Err(err).Msgf("PANIC RECOVERED\n%s", string(debug.Stack()))
// 返回 500 错误 // 返回 500 错误
c.JSON(http.StatusInternalServerError, err) c.JSON(http.StatusInternalServerError, err)

View File

@@ -35,6 +35,7 @@ type Image struct {
} }
type Video struct { type Video struct {
Md5 string `xml:"md5,attr"`
RawMd5 string `xml:"rawmd5,attr"` RawMd5 string `xml:"rawmd5,attr"`
// Length string `xml:"length,attr"` // Length string `xml:"length,attr"`
// PlayLength string `xml:"playlength,attr"` // PlayLength string `xml:"playlength,attr"`

View File

@@ -79,7 +79,12 @@ func (m *Message) ParseMediaInfo(data string) error {
case 3: case 3:
m.Contents["md5"] = msg.Image.MD5 m.Contents["md5"] = msg.Image.MD5
case 43: case 43:
m.Contents["md5"] = msg.Video.RawMd5 if msg.Video.Md5 != "" {
m.Contents["md5"] = msg.Video.Md5
}
if msg.Video.RawMd5 != "" {
m.Contents["rawmd5"] = msg.Video.RawMd5
}
case 49: case 49:
m.SubType = int64(msg.App.Type) m.SubType = int64(msg.App.Type)
switch m.SubType { switch m.SubType {
@@ -183,7 +188,11 @@ func (m *Message) SetContent(key string, value interface{}) {
m.Contents[key] = value m.Contents[key] = value
} }
func (m *Message) PlainText(showChatRoom bool, host string) string { func (m *Message) PlainText(showChatRoom bool, timeFormat string, host string) string {
if timeFormat == "" {
timeFormat = "01-02 15:04:05"
}
m.SetContent("host", host) m.SetContent("host", host)
@@ -216,7 +225,7 @@ func (m *Message) PlainText(showChatRoom bool, host string) string {
buf.WriteString("] ") buf.WriteString("] ")
} }
buf.WriteString(m.Time.Format("2006-01-02 15:04:05")) buf.WriteString(m.Time.Format(timeFormat))
buf.WriteString("\n") buf.WriteString("\n")
buf.WriteString(m.PlainTextContent()) buf.WriteString(m.PlainTextContent())
@@ -230,7 +239,23 @@ func (m *Message) PlainTextContent() string {
case 1: case 1:
return m.Content return m.Content
case 3: case 3:
return fmt.Sprintf("![图片](http://%s/image/%s)", m.Contents["host"], m.Contents["md5"]) keylist := make([]string, 0)
if m.Contents["md5"] != nil {
if md5, ok := m.Contents["md5"].(string); ok {
keylist = append(keylist, md5)
}
}
if m.Contents["imgfile"] != nil {
if imgfile, ok := m.Contents["imgfile"].(string); ok {
keylist = append(keylist, imgfile)
}
}
if m.Contents["thumb"] != nil {
if thumb, ok := m.Contents["thumb"].(string); ok {
keylist = append(keylist, thumb)
}
}
return fmt.Sprintf("![图片](http://%s/image/%s)", m.Contents["host"], strings.Join(keylist, ","))
case 34: case 34:
if voice, ok := m.Contents["voice"]; ok { if voice, ok := m.Contents["voice"]; ok {
return fmt.Sprintf("[语音](http://%s/voice/%s)", m.Contents["host"], voice) return fmt.Sprintf("[语音](http://%s/voice/%s)", m.Contents["host"], voice)
@@ -239,10 +264,28 @@ func (m *Message) PlainTextContent() string {
case 42: case 42:
return "[名片]" return "[名片]"
case 43: case 43:
if path, ok := m.Contents["path"]; ok { keylist := make([]string, 0)
return fmt.Sprintf("![视频](http://%s/data/%s)", m.Contents["host"], path) if m.Contents["md5"] != nil {
if md5, ok := m.Contents["md5"].(string); ok {
keylist = append(keylist, md5)
}
} }
return fmt.Sprintf("![视频](http://%s/video/%s)", m.Contents["host"], m.Contents["md5"]) if m.Contents["rawmd5"] != nil {
if rawmd5, ok := m.Contents["rawmd5"].(string); ok {
keylist = append(keylist, rawmd5)
}
}
if m.Contents["videofile"] != nil {
if videofile, ok := m.Contents["videofile"].(string); ok {
keylist = append(keylist, videofile)
}
}
if m.Contents["thumb"] != nil {
if thumb, ok := m.Contents["thumb"].(string); ok {
keylist = append(keylist, thumb)
}
}
return fmt.Sprintf("![视频](http://%s/video/%s)", m.Contents["host"], strings.Join(keylist, ","))
case 47: case 47:
return "[动画表情]" return "[动画表情]"
case 49: case 49:
@@ -262,7 +305,11 @@ func (m *Message) PlainTextContent() string {
if !ok { if !ok {
return "[合并转发]" return "[合并转发]"
} }
return recordInfo.String("", m.Contents["host"].(string)) host := ""
if m.Contents["host"] != nil {
host = m.Contents["host"].(string)
}
return recordInfo.String("", host)
case 33, 36: case 33, 36:
if m.Contents["title"] == "" { if m.Contents["title"] == "" {
return "[小程序]" return "[小程序]"
@@ -290,7 +337,11 @@ func (m *Message) PlainTextContent() string {
return "> [引用]\n" + m.Content return "> [引用]\n" + m.Content
} }
buf := strings.Builder{} buf := strings.Builder{}
referContent := refer.PlainText(false, m.Contents["host"].(string)) host := ""
if m.Contents["host"] != nil {
host = m.Contents["host"].(string)
}
referContent := refer.PlainText(false, "", host)
for _, line := range strings.Split(referContent, "\n") { for _, line := range strings.Split(referContent, "\n") {
if line == "" { if line == "" {
continue continue

View File

@@ -96,7 +96,7 @@ func (m *MessageV3) Wrap() *Message {
if len(parts) > 1 { if len(parts) > 1 {
path = strings.Join(parts[1:], "/") path = strings.Join(parts[1:], "/")
} }
_m.Contents["path"] = path _m.Contents["videofile"] = path
} }
} }
} }

View File

@@ -2,7 +2,10 @@ package model
import ( import (
"bytes" "bytes"
"crypto/md5"
"encoding/hex"
"fmt" "fmt"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -85,10 +88,14 @@ func (m *MessageV4) Wrap(talker string) *Message {
if packedInfo := ParsePackedInfo(m.PackedInfoData); packedInfo != nil { if packedInfo := ParsePackedInfo(m.PackedInfoData); packedInfo != nil {
// FIXME 尝试解决 v4 版本 xml 数据无法匹配到 hardlink 记录的问题 // FIXME 尝试解决 v4 版本 xml 数据无法匹配到 hardlink 记录的问题
if _m.Type == 3 && packedInfo.Image != nil { if _m.Type == 3 && packedInfo.Image != nil {
_m.Contents["md5"] = packedInfo.Image.Md5 _talkerMd5Bytes := md5.Sum([]byte(talker))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
_m.Contents["imgfile"] = filepath.Join("msg", "attach", talkerMd5, _m.Time.Format("2006-01"), "Img", fmt.Sprintf("%s.dat", packedInfo.Image.Md5))
_m.Contents["thumb"] = filepath.Join("msg", "attach", talkerMd5, _m.Time.Format("2006-01"), "Img", fmt.Sprintf("%s_t.dat", packedInfo.Image.Md5))
} }
if _m.Type == 43 && packedInfo.Video != nil { if _m.Type == 43 && packedInfo.Video != nil {
_m.Contents["md5"] = packedInfo.Video.Md5 _m.Contents["videofile"] = filepath.Join("msg", "video", _m.Time.Format("2006-01"), fmt.Sprintf("%s.mp4", packedInfo.Video.Md5))
_m.Contents["thumb"] = filepath.Join("msg", "video", _m.Time.Format("2006-01"), fmt.Sprintf("%s_thumb.jpg", packedInfo.Video.Md5))
} }
} }
} }

View File

@@ -5,6 +5,8 @@ import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"regexp"
"sort"
"strings" "strings"
"time" "time"
@@ -15,6 +17,7 @@ import (
"github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
"github.com/sjzar/chatlog/pkg/util"
) )
const ( const (
@@ -25,7 +28,7 @@ const (
Media = "media" Media = "media"
) )
var Groups = []dbm.Group{ var Groups = []*dbm.Group{
{ {
Name: Message, Name: Message,
Pattern: `^msg_([0-9]?[0-9])?\.db$`, Pattern: `^msg_([0-9]?[0-9])?\.db$`,
@@ -114,6 +117,10 @@ func (ds *DataSource) initMessageDbs() error {
dbPaths, err := ds.dbm.GetDBPath(Message) dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "db file not found") {
ds.talkerDBMap = make(map[string]string)
return nil
}
return err return err
} }
// 处理每个数据库文件 // 处理每个数据库文件
@@ -155,6 +162,10 @@ func (ds *DataSource) initMessageDbs() error {
func (ds *DataSource) initChatRoomDb() error { func (ds *DataSource) initChatRoomDb() error {
db, err := ds.dbm.GetDB(ChatRoom) db, err := ds.dbm.GetDB(ChatRoom)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "db file not found") {
ds.user2DisplayName = make(map[string]string)
return nil
}
return err return err
} }
@@ -180,70 +191,162 @@ func (ds *DataSource) initChatRoomDb() error {
return nil return nil
} }
// GetMessages 实现获取消息的方法 func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
// 在 darwinv3 中,每个联系人/群聊的消息存储在单独的表中,表名为 Chat_md5(talker)
// 首先需要找到对应的表名
if talker == "" { if talker == "" {
return nil, errors.ErrTalkerEmpty return nil, errors.ErrTalkerEmpty
} }
_talkerMd5Bytes := md5.Sum([]byte(talker)) // 解析talker参数支持多个talker以英文逗号分隔
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) talkers := util.Str2List(talker, ",")
dbPath, ok := ds.talkerDBMap[talkerMd5] if len(talkers) == 0 {
if !ok { return nil, errors.ErrTalkerEmpty
return nil, errors.TalkerNotFound(talker)
} }
db, err := ds.dbm.OpenDB(dbPath)
if err != nil {
return nil, err
}
tableName := fmt.Sprintf("Chat_%s", talkerMd5)
// 构建查询条件 // 解析sender参数支持多个发送者以英文逗号分隔
query := fmt.Sprintf(` senders := util.Str2List(sender, ",")
SELECT msgCreateTime, msgContent, messageType, mesDes
FROM %s
WHERE msgCreateTime >= ? AND msgCreateTime <= ?
ORDER BY msgCreateTime ASC
`, tableName)
if limit > 0 { // 预编译正则表达式如果有keyword
query += fmt.Sprintf(" LIMIT %d", limit) var regex *regexp.Regexp
if keyword != "" {
if offset > 0 { var err error
query += fmt.Sprintf(" OFFSET %d", offset) regex, err = regexp.Compile(keyword)
if err != nil {
return nil, errors.QueryFailed("invalid regex pattern", err)
} }
} }
// 执行查询 // 从每个相关数据库中查询消息,并在读取时进行过滤
rows, err := db.QueryContext(ctx, query, startTime.Unix(), endTime.Unix()) filteredMessages := []*model.Message{}
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果 // 对每个talker进行查询
messages := []*model.Message{} for _, talkerItem := range talkers {
for rows.Next() { // 检查上下文是否已取消
var msg model.MessageDarwinV3 if err := ctx.Err(); err != nil {
err := rows.Scan( return nil, err
&msg.MsgCreateTime, }
&msg.MsgContent,
&msg.MessageType, // 在 darwinv3 中,需要先找到对应的数据库
&msg.MesDes, _talkerMd5Bytes := md5.Sum([]byte(talkerItem))
) talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
if err != nil { dbPath, ok := ds.talkerDBMap[talkerMd5]
log.Err(err).Msgf("扫描消息行失败") if !ok {
// 如果找不到对应的数据库跳过此talker
continue continue
} }
// 将消息包装为通用模型 db, err := ds.dbm.OpenDB(dbPath)
message := msg.Wrap(talker) if err != nil {
messages = append(messages, message) log.Error().Msgf("数据库 %s 未打开", dbPath)
continue
}
tableName := fmt.Sprintf("Chat_%s", talkerMd5)
// 构建查询条件
query := fmt.Sprintf(`
SELECT msgCreateTime, msgContent, messageType, mesDes
FROM %s
WHERE msgCreateTime >= ? AND msgCreateTime <= ?
ORDER BY msgCreateTime ASC
`, tableName)
// 执行查询
rows, err := db.QueryContext(ctx, query, startTime.Unix(), endTime.Unix())
if err != nil {
// 如果表不存在跳过此talker
if strings.Contains(err.Error(), "no such table") {
continue
}
log.Err(err).Msgf("从数据库 %s 查询消息失败", dbPath)
continue
}
// 处理查询结果,在读取时进行过滤
for rows.Next() {
var msg model.MessageDarwinV3
err := rows.Scan(
&msg.MsgCreateTime,
&msg.MsgContent,
&msg.MessageType,
&msg.MesDes,
)
if err != nil {
rows.Close()
log.Err(err).Msgf("扫描消息行失败")
continue
}
// 将消息包装为通用模型
message := msg.Wrap(talkerItem)
// 应用sender过滤
if len(senders) > 0 {
senderMatch := false
for _, s := range senders {
if message.Sender == s {
senderMatch = true
break
}
}
if !senderMatch {
continue // 不匹配sender跳过此消息
}
}
// 应用keyword过滤
if regex != nil {
plainText := message.PlainTextContent()
if !regex.MatchString(plainText) {
continue // 不匹配keyword跳过此消息
}
}
// 通过所有过滤条件,保留此消息
filteredMessages = append(filteredMessages, message)
// 检查是否已经满足分页处理数量
if limit > 0 && len(filteredMessages) >= offset+limit {
// 已经获取了足够的消息,可以提前返回
rows.Close()
// 对所有消息按时间排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Seq < filteredMessages[j].Seq
})
// 处理分页
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
}
rows.Close()
} }
return messages, nil // 对所有消息按时间排序
// FIXME 不同 talker 需要使用 Time 排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Time.Before(filteredMessages[j].Time)
})
// 处理分页
if limit > 0 {
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
return filteredMessages, nil
} }
// 从表名中提取 talker // 从表名中提取 talker

View File

@@ -16,7 +16,7 @@ import (
type DataSource interface { type DataSource interface {
// 消息 // 消息
GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error)
// 联系人 // 联系人
GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error) GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error)

View File

@@ -34,7 +34,7 @@ func NewDBManager(path string) *DBManager {
} }
} }
func (d *DBManager) AddGroup(g Group) error { func (d *DBManager) AddGroup(g *Group) error {
fg, err := filemonitor.NewFileGroup(g.Name, d.path, g.Pattern, g.BlackList) fg, err := filemonitor.NewFileGroup(g.Name, d.path, g.Pattern, g.BlackList)
if err != nil { if err != nil {
return err return err

View File

@@ -9,7 +9,7 @@ import (
func TestXxx(t *testing.T) { func TestXxx(t *testing.T) {
path := "/Users/sarv/Documents/chatlog/bigjun_9e7a" path := "/Users/sarv/Documents/chatlog/bigjun_9e7a"
g := Group{ g := &Group{
Name: "session", Name: "session",
Pattern: `session\.db$`, Pattern: `session\.db$`,
BlackList: []string{}, BlackList: []string{},

View File

@@ -6,6 +6,7 @@ import (
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"regexp"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -17,6 +18,7 @@ import (
"github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
"github.com/sjzar/chatlog/pkg/util"
) )
const ( const (
@@ -27,7 +29,7 @@ const (
Voice = "voice" Voice = "voice"
) )
var Groups = []dbm.Group{ var Groups = []*dbm.Group{
{ {
Name: Message, Name: Message,
Pattern: `^message_([0-9]?[0-9])?\.db$`, Pattern: `^message_([0-9]?[0-9])?\.db$`,
@@ -113,6 +115,10 @@ func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Even
func (ds *DataSource) initMessageDbs() error { func (ds *DataSource) initMessageDbs() error {
dbPaths, err := ds.dbm.GetDBPath(Message) dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "db file not found") {
ds.messageInfos = make([]MessageDBInfo, 0)
return nil
}
return err return err
} }
@@ -171,11 +177,16 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes
return dbs return dbs
} }
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
if talker == "" { if talker == "" {
return nil, errors.ErrTalkerEmpty return nil, errors.ErrTalkerEmpty
} }
log.Debug().Msg(talker)
// 解析talker参数支持多个talker以英文逗号分隔
talkers := util.Str2List(talker, ",")
if len(talkers) == 0 {
return nil, errors.ErrTalkerEmpty
}
// 找到时间范围内的数据库文件 // 找到时间范围内的数据库文件
dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) dbInfos := ds.getDBInfosForTimeRange(startTime, endTime)
@@ -183,13 +194,21 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
return nil, errors.TimeRangeNotFound(startTime, endTime) return nil, errors.TimeRangeNotFound(startTime, endTime)
} }
if len(dbInfos) == 1 { // 解析sender参数支持多个发送者以英文逗号分隔
// LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理 senders := util.Str2List(sender, ",")
return ds.getMessagesSingleFile(ctx, dbInfos[0], startTime, endTime, talker, limit, offset)
// 预编译正则表达式如果有keyword
var regex *regexp.Regexp
if keyword != "" {
var err error
regex, err = regexp.Compile(keyword)
if err != nil {
return nil, errors.QueryFailed("invalid regex pattern", err)
}
} }
// 从每个相关数据库中查询消息 // 从每个相关数据库中查询消息,并在读取时进行过滤
totalMessages := []*model.Message{} filteredMessages := []*model.Message{}
for _, dbInfo := range dbInfos { for _, dbInfo := range dbInfos {
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -203,183 +222,141 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
continue continue
} }
messages, err := ds.getMessagesFromDB(ctx, db, startTime, endTime, talker) // 对每个talker进行查询
if err != nil { for _, talkerItem := range talkers {
log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) // 构建表名
continue _talkerMd5Bytes := md5.Sum([]byte(talkerItem))
} talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
tableName := "Msg_" + talkerMd5
totalMessages = append(totalMessages, messages...) // 检查表是否存在
var exists bool
err = db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
if limit+offset > 0 && len(totalMessages) >= limit+offset { if err != nil {
break if err == sql.ErrNoRows {
// 表不存在继续下一个talker
continue
}
return nil, errors.QueryFailed("", err)
}
// 构建查询条件
conditions := []string{"create_time >= ? AND create_time <= ?"}
args := []interface{}{startTime.Unix(), endTime.Unix()}
log.Debug().Msgf("Table name: %s", tableName)
log.Debug().Msgf("Start time: %d, End time: %d", startTime.Unix(), endTime.Unix())
query := fmt.Sprintf(`
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
// 如果表不存在SQLite 会返回错误
if strings.Contains(err.Error(), "no such table") {
continue
}
log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath)
continue
}
// 处理查询结果,在读取时进行过滤
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
&msg.Status,
)
if err != nil {
rows.Close()
return nil, errors.ScanRowFailed(err)
}
// 将消息转换为标准格式
message := msg.Wrap(talkerItem)
// 应用sender过滤
if len(senders) > 0 {
senderMatch := false
for _, s := range senders {
if message.Sender == s {
senderMatch = true
break
}
}
if !senderMatch {
continue // 不匹配sender跳过此消息
}
}
// 应用keyword过滤
if regex != nil {
plainText := message.PlainTextContent()
if !regex.MatchString(plainText) {
continue // 不匹配keyword跳过此消息
}
}
// 通过所有过滤条件,保留此消息
filteredMessages = append(filteredMessages, message)
// 检查是否已经满足分页处理数量
if limit > 0 && len(filteredMessages) >= offset+limit {
// 已经获取了足够的消息,可以提前返回
rows.Close()
// 对所有消息按时间排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Seq < filteredMessages[j].Seq
})
// 处理分页
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
}
rows.Close()
} }
} }
// 对所有消息按时间排序 // 对所有消息按时间排序
sort.Slice(totalMessages, func(i, j int) bool { sort.Slice(filteredMessages, func(i, j int) bool {
return totalMessages[i].Seq < totalMessages[j].Seq return filteredMessages[i].Seq < filteredMessages[j].Seq
}) })
// 处理分页 // 处理分页
if limit > 0 { if limit > 0 {
if offset >= len(totalMessages) { if offset >= len(filteredMessages) {
return []*model.Message{}, nil return []*model.Message{}, nil
} }
end := offset + limit end := offset + limit
if end > len(totalMessages) { if end > len(filteredMessages) {
end = len(totalMessages) end = len(filteredMessages)
} }
return totalMessages[offset:end], nil return filteredMessages[offset:end], nil
} }
return totalMessages, nil return filteredMessages, nil
}
// getMessagesSingleFile 从单个数据库文件获取消息
func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
return nil, errors.DBConnectFailed(dbInfo.FilePath, nil)
}
// 构建表名
_talkerMd5Bytes := md5.Sum([]byte(talker))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
tableName := "Msg_" + talkerMd5
// 检查表是否存在
var exists bool
err = db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
// 表不存在,返回空结果
return []*model.Message{}, nil
}
return nil, errors.QueryFailed("", err)
}
// 构建查询条件
conditions := []string{"create_time >= ? AND create_time <= ?"}
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
}
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果
messages := []*model.Message{}
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
&msg.Status,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
messages = append(messages, msg.Wrap(talker))
}
return messages, nil
}
// getMessagesFromDB 从数据库获取消息
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
// 构建表名
_talkerMd5Bytes := md5.Sum([]byte(talker))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
tableName := "Msg_" + talkerMd5
// 检查表是否存在
var exists bool
err := db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
// 表不存在,返回空结果
return []*model.Message{}, nil
}
return nil, errors.QueryFailed("", err)
}
// 构建查询条件
conditions := []string{"create_time >= ? AND create_time <= ?"}
args := []interface{}{startTime.Unix(), endTime.Unix()}
query := fmt.Sprintf(`
SELECT m.sort_seq, m.server_id, m.local_type, n.user_name, m.create_time, m.message_content, m.packed_info_data, m.status
FROM %s m
LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid
WHERE %s
ORDER BY m.sort_seq ASC
`, tableName, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
// 如果表不存在SQLite 会返回错误
if strings.Contains(err.Error(), "no such table") {
return []*model.Message{}, nil
}
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果
messages := []*model.Message{}
for rows.Next() {
var msg model.MessageV4
err := rows.Scan(
&msg.SortSeq,
&msg.ServerID,
&msg.LocalType,
&msg.UserName,
&msg.CreateTime,
&msg.MessageContent,
&msg.PackedInfoData,
&msg.Status,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
messages = append(messages, msg.Wrap(talker))
}
return messages, nil
} }
// 联系人 // 联系人

View File

@@ -2,9 +2,9 @@ package windowsv3
import ( import (
"context" "context"
"database/sql"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"regexp"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -16,6 +16,7 @@ import (
"github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
"github.com/sjzar/chatlog/pkg/util"
) )
const ( const (
@@ -27,7 +28,7 @@ const (
Voice = "voice" Voice = "voice"
) )
var Groups = []dbm.Group{ var Groups = []*dbm.Group{
{ {
Name: Message, Name: Message,
Pattern: `^MSG([0-9]?[0-9])?\.db$`, Pattern: `^MSG([0-9]?[0-9])?\.db$`,
@@ -122,6 +123,10 @@ func (ds *DataSource) initMessageDbs() error {
// 获取所有消息数据库文件路径 // 获取所有消息数据库文件路径
dbPaths, err := ds.dbm.GetDBPath(Message) dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "db file not found") {
ds.messageInfos = make([]MessageDBInfo, 0)
return nil
}
return err return err
} }
@@ -217,21 +222,38 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes
return dbs return dbs
} }
// GetMessages 实现 DataSource 接口的 GetMessages 方法 func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { if talker == "" {
return nil, errors.ErrTalkerEmpty
}
// 解析talker参数支持多个talker以英文逗号分隔
talkers := util.Str2List(talker, ",")
if len(talkers) == 0 {
return nil, errors.ErrTalkerEmpty
}
// 找到时间范围内的数据库文件 // 找到时间范围内的数据库文件
dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) dbInfos := ds.getDBInfosForTimeRange(startTime, endTime)
if len(dbInfos) == 0 { if len(dbInfos) == 0 {
return nil, errors.TimeRangeNotFound(startTime, endTime) return nil, errors.TimeRangeNotFound(startTime, endTime)
} }
if len(dbInfos) == 1 { // 解析sender参数支持多个发送者以英文逗号分隔
// LIMIT 和 OFFSET 逻辑在单文件情况下可以直接在 SQL 里处理 senders := util.Str2List(sender, ",")
return ds.getMessagesSingleFile(ctx, dbInfos[0], startTime, endTime, talker, limit, offset)
// 预编译正则表达式如果有keyword
var regex *regexp.Regexp
if keyword != "" {
var err error
regex, err = regexp.Compile(keyword)
if err != nil {
return nil, errors.QueryFailed("invalid regex pattern", err)
}
} }
// 从每个相关数据库中查询消息 // 从每个相关数据库中查询消息
totalMessages := []*model.Message{} filteredMessages := []*model.Message{}
for _, dbInfo := range dbInfos { for _, dbInfo := range dbInfos {
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -245,172 +267,137 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
continue continue
} }
messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker) // 对每个talker进行查询
if err != nil { for _, talkerItem := range talkers {
log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) // 构建查询条件
continue conditions := []string{"Sequence >= ? AND Sequence <= ?"}
} args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
totalMessages = append(totalMessages, messages...) // 添加talker条件
talkerID, ok := dbInfo.TalkerMap[talkerItem]
if ok {
conditions = append(conditions, "TalkerId = ?")
args = append(args, talkerID)
} else {
conditions = append(conditions, "StrTalker = ?")
args = append(args, talkerItem)
}
if limit+offset > 0 && len(totalMessages) >= limit+offset { query := fmt.Sprintf(`
break SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
ORDER BY Sequence ASC
`, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
// 如果表不存在跳过此talker
if strings.Contains(err.Error(), "no such table") {
continue
}
log.Err(err).Msgf("从数据库 %s 查询消息失败", dbInfo.FilePath)
continue
}
// 处理查询结果,在读取时进行过滤
for rows.Next() {
var msg model.MessageV3
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
&msg.IsSender,
&msg.Type,
&msg.SubType,
&msg.StrContent,
&compressContent,
&bytesExtra,
)
if err != nil {
rows.Close()
return nil, errors.ScanRowFailed(err)
}
msg.CompressContent = compressContent
msg.BytesExtra = bytesExtra
// 将消息转换为标准格式
message := msg.Wrap()
// 应用sender过滤
if len(senders) > 0 {
senderMatch := false
for _, s := range senders {
if message.Sender == s {
senderMatch = true
break
}
}
if !senderMatch {
continue // 不匹配sender跳过此消息
}
}
// 应用keyword过滤
if regex != nil {
plainText := message.PlainTextContent()
if !regex.MatchString(plainText) {
continue // 不匹配keyword跳过此消息
}
}
// 通过所有过滤条件,保留此消息
filteredMessages = append(filteredMessages, message)
// 检查是否已经满足分页处理数量
if limit > 0 && len(filteredMessages) >= offset+limit {
// 已经获取了足够的消息,可以提前返回
rows.Close()
// 对所有消息按时间排序
sort.Slice(filteredMessages, func(i, j int) bool {
return filteredMessages[i].Seq < filteredMessages[j].Seq
})
// 处理分页
if offset >= len(filteredMessages) {
return []*model.Message{}, nil
}
end := offset + limit
if end > len(filteredMessages) {
end = len(filteredMessages)
}
return filteredMessages[offset:end], nil
}
}
rows.Close()
} }
} }
// 对所有消息按时间排序 // 对所有消息按时间排序
sort.Slice(totalMessages, func(i, j int) bool { sort.Slice(filteredMessages, func(i, j int) bool {
return totalMessages[i].Seq < totalMessages[j].Seq return filteredMessages[i].Seq < filteredMessages[j].Seq
}) })
// 处理分页 // 处理分页
if limit > 0 { if limit > 0 {
if offset >= len(totalMessages) { if offset >= len(filteredMessages) {
return []*model.Message{}, nil return []*model.Message{}, nil
} }
end := offset + limit end := offset + limit
if end > len(totalMessages) { if end > len(filteredMessages) {
end = len(totalMessages) end = len(filteredMessages)
} }
return totalMessages[offset:end], nil return filteredMessages[offset:end], nil
} }
return totalMessages, nil return filteredMessages, nil
}
// getMessagesSingleFile 从单个数据库文件获取消息
func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
return nil, errors.DBConnectFailed(dbInfo.FilePath, nil)
}
// 构建查询条件
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
if len(talker) > 0 {
// TalkerId 有索引,优先使用
talkerID, ok := dbInfo.TalkerMap[talker]
if ok {
conditions = append(conditions, "TalkerId = ?")
args = append(args, talkerID)
} else {
conditions = append(conditions, "StrTalker = ?")
args = append(args, talker)
}
}
query := fmt.Sprintf(`
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
ORDER BY Sequence ASC
`, strings.Join(conditions, " AND "))
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
}
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果
totalMessages := []*model.Message{}
for rows.Next() {
var msg model.MessageV3
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
&msg.IsSender,
&msg.Type,
&msg.SubType,
&msg.StrContent,
&compressContent,
&bytesExtra,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
msg.CompressContent = compressContent
msg.BytesExtra = bytesExtra
totalMessages = append(totalMessages, msg.Wrap())
}
return totalMessages, nil
}
// getMessagesFromDB 从数据库获取消息
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
// 构建查询条件
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
if len(talker) > 0 {
talkerID, ok := dbInfo.TalkerMap[talker]
if ok {
conditions = append(conditions, "TalkerId = ?")
args = append(args, talkerID)
} else {
conditions = append(conditions, "StrTalker = ?")
args = append(args, talker)
}
}
query := fmt.Sprintf(`
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
ORDER BY Sequence ASC
`, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果
messages := []*model.Message{}
for rows.Next() {
var msg model.MessageV3
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
&msg.IsSender,
&msg.Type,
&msg.SubType,
&msg.StrContent,
&compressContent,
&bytesExtra,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
msg.CompressContent = compressContent
msg.BytesExtra = bytesExtra
messages = append(messages, msg.Wrap())
}
return messages, nil
} }
// GetContacts 实现获取联系人信息的方法 // GetContacts 实现获取联系人信息的方法

View File

@@ -18,8 +18,8 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error {
} }
chatRoomMap := make(map[string]*model.ChatRoom) chatRoomMap := make(map[string]*model.ChatRoom)
remarkToChatRoom := make(map[string]*model.ChatRoom) remarkToChatRoom := make(map[string][]*model.ChatRoom)
nickNameToChatRoom := make(map[string]*model.ChatRoom) nickNameToChatRoom := make(map[string][]*model.ChatRoom)
chatRoomList := make([]string, 0) chatRoomList := make([]string, 0)
chatRoomRemark := make([]string, 0) chatRoomRemark := make([]string, 0)
chatRoomNickName := make([]string, 0) chatRoomNickName := make([]string, 0)
@@ -30,11 +30,21 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error {
chatRoomMap[chatRoom.Name] = chatRoom chatRoomMap[chatRoom.Name] = chatRoom
chatRoomList = append(chatRoomList, chatRoom.Name) chatRoomList = append(chatRoomList, chatRoom.Name)
if chatRoom.Remark != "" { if chatRoom.Remark != "" {
remarkToChatRoom[chatRoom.Remark] = chatRoom remark, ok := remarkToChatRoom[chatRoom.Remark]
if !ok {
remark = make([]*model.ChatRoom, 0)
}
remark = append(remark, chatRoom)
remarkToChatRoom[chatRoom.Remark] = remark
chatRoomRemark = append(chatRoomRemark, chatRoom.Remark) chatRoomRemark = append(chatRoomRemark, chatRoom.Remark)
} }
if chatRoom.NickName != "" { if chatRoom.NickName != "" {
nickNameToChatRoom[chatRoom.NickName] = chatRoom nickName, ok := nickNameToChatRoom[chatRoom.NickName]
if !ok {
nickName = make([]*model.ChatRoom, 0)
}
nickName = append(nickName, chatRoom)
nickNameToChatRoom[chatRoom.NickName] = nickName
chatRoomNickName = append(chatRoomNickName, chatRoom.NickName) chatRoomNickName = append(chatRoomNickName, chatRoom.NickName)
} }
} }
@@ -49,11 +59,21 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error {
chatRoomMap[contact.UserName] = chatRoom chatRoomMap[contact.UserName] = chatRoom
chatRoomList = append(chatRoomList, contact.UserName) chatRoomList = append(chatRoomList, contact.UserName)
if contact.Remark != "" { if contact.Remark != "" {
remarkToChatRoom[contact.Remark] = chatRoom remark, ok := remarkToChatRoom[chatRoom.Remark]
if !ok {
remark = make([]*model.ChatRoom, 0)
}
remark = append(remark, chatRoom)
remarkToChatRoom[chatRoom.Remark] = remark
chatRoomRemark = append(chatRoomRemark, contact.Remark) chatRoomRemark = append(chatRoomRemark, contact.Remark)
} }
if contact.NickName != "" { if contact.NickName != "" {
nickNameToChatRoom[contact.NickName] = chatRoom nickName, ok := nickNameToChatRoom[chatRoom.NickName]
if !ok {
nickName = make([]*model.ChatRoom, 0)
}
nickName = append(nickName, chatRoom)
nickNameToChatRoom[chatRoom.NickName] = nickName
chatRoomNickName = append(chatRoomNickName, contact.NickName) chatRoomNickName = append(chatRoomNickName, contact.NickName)
} }
} }
@@ -63,9 +83,12 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error {
sort.Strings(chatRoomNickName) sort.Strings(chatRoomNickName)
r.chatRoomCache = chatRoomMap r.chatRoomCache = chatRoomMap
r.chatRoomList = chatRoomList
r.remarkToChatRoom = remarkToChatRoom r.remarkToChatRoom = remarkToChatRoom
r.nickNameToChatRoom = nickNameToChatRoom r.nickNameToChatRoom = nickNameToChatRoom
r.chatRoomList = chatRoomList
r.chatRoomRemark = chatRoomRemark
r.chatRoomNickName = chatRoomNickName
return nil return nil
} }
@@ -75,7 +98,7 @@ func (r *Repository) GetChatRooms(ctx context.Context, key string, limit, offset
if key != "" { if key != "" {
ret = r.findChatRooms(key) ret = r.findChatRooms(key)
if len(ret) == 0 { if len(ret) == 0 {
return nil, errors.ChatRoomNotFound(key) return []*model.ChatRoom{}, nil
} }
if limit > 0 { if limit > 0 {
@@ -129,21 +152,21 @@ func (r *Repository) findChatRoom(key string) *model.ChatRoom {
return chatRoom return chatRoom
} }
if chatRoom, ok := r.remarkToChatRoom[key]; ok { if chatRoom, ok := r.remarkToChatRoom[key]; ok {
return chatRoom return chatRoom[0]
} }
if chatRoom, ok := r.nickNameToChatRoom[key]; ok { if chatRoom, ok := r.nickNameToChatRoom[key]; ok {
return chatRoom return chatRoom[0]
} }
// Contain // Contain
for _, remark := range r.chatRoomRemark { for _, remark := range r.chatRoomRemark {
if strings.Contains(remark, key) { if strings.Contains(remark, key) {
return r.remarkToChatRoom[remark] return r.remarkToChatRoom[remark][0]
} }
} }
for _, nickName := range r.chatRoomNickName { for _, nickName := range r.chatRoomNickName {
if strings.Contains(nickName, key) { if strings.Contains(nickName, key) {
return r.nickNameToChatRoom[nickName] return r.nickNameToChatRoom[nickName][0]
} }
} }
@@ -157,26 +180,42 @@ func (r *Repository) findChatRooms(key string) []*model.ChatRoom {
ret = append(ret, chatRoom) ret = append(ret, chatRoom)
distinct[chatRoom.Name] = true distinct[chatRoom.Name] = true
} }
if chatRoom, ok := r.remarkToChatRoom[key]; ok && !distinct[chatRoom.Name] { if chatRooms, ok := r.remarkToChatRoom[key]; ok {
ret = append(ret, chatRoom) for _, chatRoom := range chatRooms {
distinct[chatRoom.Name] = true if !distinct[chatRoom.Name] {
ret = append(ret, chatRoom)
distinct[chatRoom.Name] = true
}
}
} }
if chatRoom, ok := r.nickNameToChatRoom[key]; ok && !distinct[chatRoom.Name] { if chatRooms, ok := r.nickNameToChatRoom[key]; ok {
ret = append(ret, chatRoom) for _, chatRoom := range chatRooms {
distinct[chatRoom.Name] = true if !distinct[chatRoom.Name] {
ret = append(ret, chatRoom)
distinct[chatRoom.Name] = true
}
}
} }
// Contain // Contain
for _, remark := range r.chatRoomRemark { for _, remark := range r.chatRoomRemark {
if strings.Contains(remark, key) && !distinct[r.remarkToChatRoom[remark].Name] { if strings.Contains(remark, key) {
ret = append(ret, r.remarkToChatRoom[remark]) for _, chatRoom := range r.remarkToChatRoom[remark] {
distinct[r.remarkToChatRoom[remark].Name] = true if !distinct[chatRoom.Name] {
ret = append(ret, chatRoom)
distinct[chatRoom.Name] = true
}
}
} }
} }
for _, nickName := range r.chatRoomNickName { for _, nickName := range r.chatRoomNickName {
if strings.Contains(nickName, key) && !distinct[r.nickNameToChatRoom[nickName].Name] { if strings.Contains(nickName, key) {
ret = append(ret, r.nickNameToChatRoom[nickName]) for _, chatRoom := range r.nickNameToChatRoom[nickName] {
distinct[r.nickNameToChatRoom[nickName].Name] = true if !distinct[chatRoom.Name] {
ret = append(ret, chatRoom)
distinct[chatRoom.Name] = true
}
}
} }
} }

View File

@@ -18,9 +18,9 @@ func (r *Repository) initContactCache(ctx context.Context) error {
} }
contactMap := make(map[string]*model.Contact) contactMap := make(map[string]*model.Contact)
aliasMap := make(map[string]*model.Contact) aliasMap := make(map[string][]*model.Contact)
remarkMap := make(map[string]*model.Contact) remarkMap := make(map[string][]*model.Contact)
nickNameMap := make(map[string]*model.Contact) nickNameMap := make(map[string][]*model.Contact)
chatRoomUserMap := make(map[string]*model.Contact) chatRoomUserMap := make(map[string]*model.Contact)
chatRoomInContactMap := make(map[string]*model.Contact) chatRoomInContactMap := make(map[string]*model.Contact)
contactList := make([]string, 0) contactList := make([]string, 0)
@@ -34,15 +34,30 @@ func (r *Repository) initContactCache(ctx context.Context) error {
// 建立快速查找索引 // 建立快速查找索引
if contact.Alias != "" { if contact.Alias != "" {
aliasMap[contact.Alias] = contact alias, ok := aliasMap[contact.Alias]
if !ok {
alias = make([]*model.Contact, 0)
}
alias = append(alias, contact)
aliasMap[contact.Alias] = alias
aliasList = append(aliasList, contact.Alias) aliasList = append(aliasList, contact.Alias)
} }
if contact.Remark != "" { if contact.Remark != "" {
remarkMap[contact.Remark] = contact remark, ok := remarkMap[contact.Remark]
if !ok {
remark = make([]*model.Contact, 0)
}
remark = append(remark, contact)
remarkMap[contact.Remark] = remark
remarkList = append(remarkList, contact.Remark) remarkList = append(remarkList, contact.Remark)
} }
if contact.NickName != "" { if contact.NickName != "" {
nickNameMap[contact.NickName] = contact nickName, ok := nickNameMap[contact.NickName]
if !ok {
nickName = make([]*model.Contact, 0)
}
nickName = append(nickName, contact)
nickNameMap[contact.NickName] = nickName
nickNameList = append(nickNameList, contact.NickName) nickNameList = append(nickNameList, contact.NickName)
} }
@@ -88,7 +103,7 @@ func (r *Repository) GetContacts(ctx context.Context, key string, limit, offset
if key != "" { if key != "" {
ret = r.findContacts(key) ret = r.findContacts(key)
if len(ret) == 0 { if len(ret) == 0 {
return nil, errors.ContactNotFound(key) return []*model.Contact{}, nil
} }
if limit > 0 { if limit > 0 {
end := offset + limit end := offset + limit
@@ -124,29 +139,29 @@ func (r *Repository) findContact(key string) *model.Contact {
return contact return contact
} }
if contact, ok := r.aliasToContact[key]; ok { if contact, ok := r.aliasToContact[key]; ok {
return contact return contact[0]
} }
if contact, ok := r.remarkToContact[key]; ok { if contact, ok := r.remarkToContact[key]; ok {
return contact return contact[0]
} }
if contact, ok := r.nickNameToContact[key]; ok { if contact, ok := r.nickNameToContact[key]; ok {
return contact return contact[0]
} }
// Contain // Contain
for _, alias := range r.aliasList { for _, alias := range r.aliasList {
if strings.Contains(alias, key) { if strings.Contains(alias, key) {
return r.aliasToContact[alias] return r.aliasToContact[alias][0]
} }
} }
for _, remark := range r.remarkList { for _, remark := range r.remarkList {
if strings.Contains(remark, key) { if strings.Contains(remark, key) {
return r.remarkToContact[remark] return r.remarkToContact[remark][0]
} }
} }
for _, nickName := range r.nickNameList { for _, nickName := range r.nickNameList {
if strings.Contains(nickName, key) { if strings.Contains(nickName, key) {
return r.nickNameToContact[nickName] return r.nickNameToContact[nickName][0]
} }
} }
return nil return nil
@@ -159,37 +174,62 @@ func (r *Repository) findContacts(key string) []*model.Contact {
ret = append(ret, contact) ret = append(ret, contact)
distinct[contact.UserName] = true distinct[contact.UserName] = true
} }
if contact, ok := r.aliasToContact[key]; ok && !distinct[contact.UserName] { if contacts, ok := r.aliasToContact[key]; ok {
ret = append(ret, contact) for _, contact := range contacts {
distinct[contact.UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
if contact, ok := r.remarkToContact[key]; ok && !distinct[contact.UserName] { if contacts, ok := r.remarkToContact[key]; ok {
ret = append(ret, contact) for _, contact := range contacts {
distinct[contact.UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
if contact, ok := r.nickNameToContact[key]; ok && !distinct[contact.UserName] { if contacts, ok := r.nickNameToContact[key]; ok {
ret = append(ret, contact) for _, contact := range contacts {
distinct[contact.UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
// Contain // Contain
for _, alias := range r.aliasList { for _, alias := range r.aliasList {
if strings.Contains(alias, key) && !distinct[r.aliasToContact[alias].UserName] { if strings.Contains(alias, key) {
ret = append(ret, r.aliasToContact[alias]) for _, contact := range r.aliasToContact[alias] {
distinct[r.aliasToContact[alias].UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
} }
for _, remark := range r.remarkList { for _, remark := range r.remarkList {
if strings.Contains(remark, key) && !distinct[r.remarkToContact[remark].UserName] { if strings.Contains(remark, key) {
ret = append(ret, r.remarkToContact[remark]) for _, contact := range r.remarkToContact[remark] {
distinct[r.remarkToContact[remark].UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
} }
for _, nickName := range r.nickNameList { for _, nickName := range r.nickNameList {
if strings.Contains(nickName, key) && !distinct[r.nickNameToContact[nickName].UserName] { if strings.Contains(nickName, key) {
ret = append(ret, r.nickNameToContact[nickName]) for _, contact := range r.nickNameToContact[nickName] {
distinct[r.nickNameToContact[nickName].UserName] = true if !distinct[contact.UserName] {
ret = append(ret, contact)
distinct[contact.UserName] = true
}
}
} }
} }
return ret return ret
} }

View File

@@ -2,23 +2,20 @@ package repository
import ( import (
"context" "context"
"strings"
"time" "time"
"github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/pkg/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// GetMessages 实现 Repository 接口的 GetMessages 方法 // GetMessages 实现 Repository 接口的 GetMessages 方法
func (r *Repository) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { func (r *Repository) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
if contact, _ := r.GetContact(ctx, talker); contact != nil { talker, sender = r.parseTalkerAndSender(ctx, talker, sender)
talker = contact.UserName messages, err := r.ds.GetMessages(ctx, startTime, endTime, talker, sender, keyword, limit, offset)
} else if chatRoom, _ := r.GetChatRoom(ctx, talker); chatRoom != nil {
talker = chatRoom.Name
}
messages, err := r.ds.GetMessages(ctx, startTime, endTime, talker, limit, offset)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -62,3 +59,53 @@ func (r *Repository) enrichMessage(msg *model.Message) {
} }
} }
} }
func (r *Repository) parseTalkerAndSender(ctx context.Context, talker, sender string) (string, string) {
displayName2User := make(map[string]string)
users := make(map[string]bool)
talkers := util.Str2List(talker, ",")
if len(talkers) > 0 {
for i := 0; i < len(talkers); i++ {
if contact, _ := r.GetContact(ctx, talkers[i]); contact != nil {
talkers[i] = contact.UserName
} else if chatRoom, _ := r.GetChatRoom(ctx, talker); chatRoom != nil {
talkers[i] = chatRoom.Name
}
}
// 获取群聊的用户列表
for i := 0; i < len(talkers); i++ {
if chatRoom, _ := r.GetChatRoom(ctx, talkers[i]); chatRoom != nil {
for user, displayName := range chatRoom.User2DisplayName {
displayName2User[displayName] = user
}
for _, user := range chatRoom.Users {
users[user.UserName] = true
}
}
}
talker = strings.Join(talkers, ",")
}
senders := util.Str2List(sender, ",")
if len(senders) > 0 {
for i := 0; i < len(senders); i++ {
if user, ok := displayName2User[senders[i]]; ok {
senders[i] = user
} else {
// FIXME 大量群聊用户名称重复,无法直接通过 GetContact 获取 ID后续再优化
for user := range users {
if contact := r.getFullContact(user); contact != nil {
if contact.DisplayName() == senders[i] {
senders[i] = user
break
}
}
}
}
}
sender = strings.Join(senders, ",")
}
return talker, sender
}

View File

@@ -17,9 +17,9 @@ type Repository struct {
// Cache for contact // Cache for contact
contactCache map[string]*model.Contact contactCache map[string]*model.Contact
aliasToContact map[string]*model.Contact aliasToContact map[string][]*model.Contact
remarkToContact map[string]*model.Contact remarkToContact map[string][]*model.Contact
nickNameToContact map[string]*model.Contact nickNameToContact map[string][]*model.Contact
chatRoomInContact map[string]*model.Contact chatRoomInContact map[string]*model.Contact
contactList []string contactList []string
aliasList []string aliasList []string
@@ -28,8 +28,8 @@ type Repository struct {
// Cache for chat room // Cache for chat room
chatRoomCache map[string]*model.ChatRoom chatRoomCache map[string]*model.ChatRoom
remarkToChatRoom map[string]*model.ChatRoom remarkToChatRoom map[string][]*model.ChatRoom
nickNameToChatRoom map[string]*model.ChatRoom nickNameToChatRoom map[string][]*model.ChatRoom
chatRoomList []string chatRoomList []string
chatRoomRemark []string chatRoomRemark []string
chatRoomNickName []string chatRoomNickName []string
@@ -43,17 +43,17 @@ func New(ds datasource.DataSource) (*Repository, error) {
r := &Repository{ r := &Repository{
ds: ds, ds: ds,
contactCache: make(map[string]*model.Contact), contactCache: make(map[string]*model.Contact),
aliasToContact: make(map[string]*model.Contact), aliasToContact: make(map[string][]*model.Contact),
remarkToContact: make(map[string]*model.Contact), remarkToContact: make(map[string][]*model.Contact),
nickNameToContact: make(map[string]*model.Contact), nickNameToContact: make(map[string][]*model.Contact),
chatRoomUserToInfo: make(map[string]*model.Contact), chatRoomUserToInfo: make(map[string]*model.Contact),
contactList: make([]string, 0), contactList: make([]string, 0),
aliasList: make([]string, 0), aliasList: make([]string, 0),
remarkList: make([]string, 0), remarkList: make([]string, 0),
nickNameList: make([]string, 0), nickNameList: make([]string, 0),
chatRoomCache: make(map[string]*model.ChatRoom), chatRoomCache: make(map[string]*model.ChatRoom),
remarkToChatRoom: make(map[string]*model.ChatRoom), remarkToChatRoom: make(map[string][]*model.ChatRoom),
nickNameToChatRoom: make(map[string]*model.ChatRoom), nickNameToChatRoom: make(map[string][]*model.ChatRoom),
chatRoomList: make([]string, 0), chatRoomList: make([]string, 0),
chatRoomRemark: make([]string, 0), chatRoomRemark: make([]string, 0),
chatRoomNickName: make([]string, 0), chatRoomNickName: make([]string, 0),

View File

@@ -57,11 +57,11 @@ func (w *DB) Initialize() error {
return nil return nil
} }
func (w *DB) GetMessages(start, end time.Time, talker string, limit, offset int) ([]*model.Message, error) { func (w *DB) GetMessages(start, end time.Time, talker string, sender string, keyword string, limit, offset int) ([]*model.Message, error) {
ctx := context.Background() ctx := context.Background()
// 使用 repository 获取消息 // 使用 repository 获取消息
messages, err := w.repo.GetMessages(ctx, start, end, talker, limit, offset) messages, err := w.repo.GetMessages(ctx, start, end, talker, sender, keyword, limit, offset)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -17,6 +17,7 @@ import (
// Format defines the header and extension for different image types // Format defines the header and extension for different image types
type Format struct { type Format struct {
Header []byte Header []byte
AesKey []byte
Ext string Ext string
} }
@@ -29,10 +30,13 @@ var (
BMP = Format{Header: []byte{0x42, 0x4D}, Ext: "bmp"} BMP = Format{Header: []byte{0x42, 0x4D}, Ext: "bmp"}
Formats = []Format{JPG, PNG, GIF, TIFF, BMP} Formats = []Format{JPG, PNG, GIF, TIFF, BMP}
V4Format1 = Format{Header: []byte{0x07, 0x08, 0x56, 0x31}, AesKey: []byte("cfcd208495d565ef")}
V4Format2 = Format{Header: []byte{0x07, 0x08, 0x56, 0x32}, AesKey: []byte("0000000000000000")} // FIXME
V4Formats = []Format{V4Format1, V4Format2}
// WeChat v4 related constants // WeChat v4 related constants
V4XorKey byte = 0x37 // Default XOR key for WeChat v4 dat files V4XorKey byte = 0x37 // Default XOR key for WeChat v4 dat files
V4DatHeader = []byte{0x07, 0x08, 0x56, 0x31} // WeChat v4 dat file header JpgTail = []byte{0xFF, 0xD9} // JPG file tail marker
JpgTail = []byte{0xFF, 0xD9} // JPG file tail marker
) )
// Dat2Image converts WeChat dat file data to image data // Dat2Image converts WeChat dat file data to image data
@@ -43,8 +47,12 @@ func Dat2Image(data []byte) ([]byte, string, error) {
} }
// Check if this is a WeChat v4 dat file // Check if this is a WeChat v4 dat file
if len(data) >= 6 && bytes.Equal(data[:4], V4DatHeader) { if len(data) >= 6 {
return Dat2ImageV4(data) for _, format := range V4Formats {
if bytes.Equal(data[:4], format.Header) {
return Dat2ImageV4(data, format.AesKey)
}
}
} }
// For older WeChat versions, use XOR decryption // For older WeChat versions, use XOR decryption
@@ -134,7 +142,7 @@ func ScanAndSetXorKey(dirPath string) (byte, error) {
} }
// Check if it's a WeChat v4 dat file // Check if it's a WeChat v4 dat file
if len(data) < 6 || !bytes.Equal(data[:4], V4DatHeader) { if len(data) < 6 || (!bytes.Equal(data[:4], V4Format1.Header) && !bytes.Equal(data[:4], V4Format2.Header)) {
return nil return nil
} }
@@ -179,13 +187,13 @@ func ScanAndSetXorKey(dirPath string) (byte, error) {
// Dat2ImageV4 processes WeChat v4 dat image files // Dat2ImageV4 processes WeChat v4 dat image files
// WeChat v4 uses a combination of AES-ECB and XOR encryption // WeChat v4 uses a combination of AES-ECB and XOR encryption
func Dat2ImageV4(data []byte) ([]byte, string, error) { func Dat2ImageV4(data []byte, aeskey []byte) ([]byte, string, error) {
if len(data) < 15 { if len(data) < 15 {
return nil, "", fmt.Errorf("data length is too short for WeChat v4 format: %d", len(data)) return nil, "", fmt.Errorf("data length is too short for WeChat v4 format: %d", len(data))
} }
// Parse dat file header: // Parse dat file header:
// - 6 bytes: 0x07085631 (dat file identifier) // - 6 bytes: 0x07085631 or 0x07085632 (dat file identifier)
// - 4 bytes: int (little-endian) AES-ECB128 encryption length // - 4 bytes: int (little-endian) AES-ECB128 encryption length
// - 4 bytes: int (little-endian) XOR encryption length // - 4 bytes: int (little-endian) XOR encryption length
// - 1 byte: 0x01 (unknown) // - 1 byte: 0x01 (unknown)
@@ -206,7 +214,7 @@ func Dat2ImageV4(data []byte) ([]byte, string, error) {
} }
// Decrypt AES part // Decrypt AES part
aesDecryptedData, err := decryptAESECB(fileData[:aesEncryptLen0], []byte("cfcd208495d565ef")) aesDecryptedData, err := decryptAESECB(fileData[:aesEncryptLen0], aeskey)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("AES decrypt error: %v", err) return nil, "", fmt.Errorf("AES decrypt error: %v", err)
} }

View File

@@ -3,6 +3,7 @@ package util
import ( import (
"fmt" "fmt"
"strconv" "strconv"
"strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
@@ -45,3 +46,26 @@ func IsNumeric(s string) bool {
func SplitInt64ToTwoInt32(input int64) (int64, int64) { func SplitInt64ToTwoInt32(input int64) (int64, int64) {
return input & 0xFFFFFFFF, input >> 32 return input & 0xFFFFFFFF, input >> 32
} }
func Str2List(str string, sep string) []string {
list := make([]string, 0)
if str == "" {
return list
}
listMap := make(map[string]bool)
for _, elem := range strings.Split(str, sep) {
elem = strings.TrimSpace(elem)
if len(elem) == 0 {
continue
}
if _, ok := listMap[elem]; ok {
continue
}
listMap[elem] = true
list = append(list, elem)
}
return list
}

View File

@@ -582,8 +582,8 @@ func adjustStartTime(t time.Time, g TimeGranularity) time.Time {
func adjustEndTime(t time.Time, g TimeGranularity) time.Time { func adjustEndTime(t time.Time, g TimeGranularity) time.Time {
switch g { switch g {
case GranularitySecond, GranularityMinute, GranularityHour: case GranularitySecond, GranularityMinute, GranularityHour:
// 对于精确到秒/分钟/小时的时间,设置为当天结束 // 对于精确到秒/分钟/小时的时间,保持原样
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location()) return t
case GranularityDay: case GranularityDay:
// 精确到天,设置为当天结束 // 精确到天,设置为当天结束
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location()) return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location())
@@ -634,3 +634,25 @@ func isValidDate(year, month, day int) bool {
return day <= daysInMonth return day <= daysInMonth
} }
func PerfectTimeFormat(start time.Time, end time.Time) string {
endTime := end
// 如果结束时间是某一天的 0 点整,将其减去 1 秒,视为前一天的结束
if endTime.Hour() == 0 && endTime.Minute() == 0 && endTime.Second() == 0 && endTime.Nanosecond() == 0 {
endTime = endTime.Add(-time.Second) // 减去 1 秒
}
// 判断是否跨年
if start.Year() != endTime.Year() {
return "2006-01-02 15:04:05" // 完整格式,包含年月日时分秒
}
// 判断是否跨天(但在同一年内)
if start.YearDay() != endTime.YearDay() {
return "01-02 15:04:05" // 月日时分秒格式
}
// 在同一天内
return "15:04:05" // 只显示时分秒
}