diff --git a/README.md b/README.md index 2bc3f26..d84a55e 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ go install github.com/sjzar/chatlog@latest ### macOS 版本提示 -1. macOS 用户在获取密钥前,需要确认已经关闭 SIP 并安装 Xcode。由于 macOS 的安全机制,在正常情况在无法读取微信进程的内存数据,所以需要临时关闭 SIP。关闭 SIP 的方法: +1. macOS 用户在获取密钥前,需要确认已经关闭 SIP 并安装 Xcode Command Line Tools。由于 macOS 的安全机制,在正常情况在无法读取微信进程的内存数据,所以需要临时关闭 SIP。关闭 SIP 的方法: ```shell # 1. 进入恢复模式 @@ -72,10 +72,17 @@ go install github.com/sjzar/chatlog@latest # 4. 重启系统 ``` -2. 目前的 macOS 版本方案依赖 `lldb` 工具,所以需要安装 Xcode,可以从 App Store 进行下载。 +2. 目前的 macOS 版本方案依赖 `lldb` 工具,所以需要安装 Xcode Command Line Tools。 + +```shell +# 在 terminal 执行以下命令安装 Xcode Command Line Tools: +xcode-select --install +``` 3. 仅获取数据密钥步骤需要关闭 SIP;获取数据密钥后即可重新打开 SIP,不影响解密数据和 HTTP 服务的运行。 +4. 如果是 Apple Silicon 芯片的 mac 用户,请检查 微信、chatlog、terminal 均不要运行在 Rosetta 模式下运行,否则可能无法获取密钥。 + ### Terminal UI 模式 1. 启动程序: diff --git a/internal/chatlog/app.go b/internal/chatlog/app.go index 4f77bb0..e3f5576 100644 --- a/internal/chatlog/app.go +++ b/internal/chatlog/app.go @@ -2,14 +2,17 @@ package chatlog import ( "fmt" + "path/filepath" "runtime" "time" "github.com/sjzar/chatlog/internal/chatlog/ctx" "github.com/sjzar/chatlog/internal/ui/footer" + "github.com/sjzar/chatlog/internal/ui/form" "github.com/sjzar/chatlog/internal/ui/help" "github.com/sjzar/chatlog/internal/ui/infobar" "github.com/sjzar/chatlog/internal/ui/menu" + "github.com/sjzar/chatlog/internal/wechat" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -54,6 +57,8 @@ func NewApp(ctx *ctx.Context, m *Manager) *App { app.initMenu() + app.updateMenuItemsState() + return app } @@ -91,6 +96,33 @@ func (a *App) Stop() { a.Application.Stop() } +func (a *App) updateMenuItemsState() { + // 查找并更新自动解密菜单项 + for _, item := range a.menu.GetItems() { + // 更新自动解密菜单项 + if item.Index == 5 { + if a.ctx.AutoDecrypt { + item.Name = "停止自动解密" + item.Description = "停止监控数据目录更新,不再自动解密新增数据" + } else { + item.Name = "开启自动解密" + item.Description = "监控数据目录更新,自动解密新增数据" + } + } + + // 更新HTTP服务菜单项 + if item.Index == 4 { + if a.ctx.HTTPEnabled { + item.Name = "停止 HTTP 服务" + item.Description = "停止本地 HTTP & MCP 服务器" + } else { + item.Name = "启动 HTTP 服务" + item.Description = "启动本地 HTTP & MCP 服务器" + } + } + } +} + func (a *App) switchTab(step int) { index := (a.activeTab + step) % a.tabCount if index < 0 { @@ -109,17 +141,29 @@ func (a *App) refresh() { case <-a.stopRefresh: return case <-tick.C: + if a.ctx.AutoDecrypt || a.ctx.HTTPEnabled { + a.m.RefreshSession() + } a.infoBar.UpdateAccount(a.ctx.Account) a.infoBar.UpdateBasicInfo(a.ctx.PID, a.ctx.FullVersion, a.ctx.ExePath) a.infoBar.UpdateStatus(a.ctx.Status) a.infoBar.UpdateDataKey(a.ctx.DataKey) + a.infoBar.UpdatePlatform(a.ctx.Platform) a.infoBar.UpdateDataUsageDir(a.ctx.DataUsage, a.ctx.DataDir) a.infoBar.UpdateWorkUsageDir(a.ctx.WorkUsage, a.ctx.WorkDir) + if a.ctx.LastSession.Unix() > 1000000000 { + a.infoBar.UpdateSession(a.ctx.LastSession.Format("2006-01-02 15:04:05")) + } if a.ctx.HTTPEnabled { a.infoBar.UpdateHTTPServer(fmt.Sprintf("[green][已启动][white] [%s]", a.ctx.HTTPAddr)) } else { a.infoBar.UpdateHTTPServer("[未启动]") } + if a.ctx.AutoDecrypt { + a.infoBar.UpdateAutoDecrypt("[green][已开启][white]") + } else { + a.infoBar.UpdateAutoDecrypt("[未开启]") + } a.Draw() } @@ -257,11 +301,11 @@ func (a *App) initMenu() { } else { // 启动成功 modal.SetText("已启动 HTTP 服务") - // 更改菜单项名称 - i.Name = "停止 HTTP 服务" - i.Description = "停止本地 HTTP & MCP 服务器" } + // 更改菜单项名称 + a.updateMenuItemsState() + // 添加确认按钮 modal.AddButtons([]string{"OK"}) modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) { @@ -288,11 +332,89 @@ func (a *App) initMenu() { } else { // 停止成功 modal.SetText("已停止 HTTP 服务") - // 更改菜单项名称 - i.Name = "启动 HTTP 服务" - i.Description = "启动本地 HTTP & MCP 服务器" } + // 更改菜单项名称 + a.updateMenuItemsState() + + // 添加确认按钮 + modal.AddButtons([]string{"OK"}) + modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) { + a.mainPages.RemovePage("modal") + }) + a.SetFocus(modal) + }) + }() + } + }, + } + + autoDecrypt := &menu.Item{ + Index: 5, + Name: "开启自动解密", + Description: "自动解密新增的数据文件", + Selected: func(i *menu.Item) { + modal := tview.NewModal() + + // 根据当前自动解密状态执行不同操作 + if !a.ctx.AutoDecrypt { + // 自动解密未开启,开启自动解密 + modal.SetText("正在开启自动解密...") + a.mainPages.AddPage("modal", modal, true, true) + a.SetFocus(modal) + + // 在后台开启自动解密 + go func() { + err := a.m.StartAutoDecrypt() + + // 在主线程中更新UI + a.QueueUpdateDraw(func() { + if err != nil { + // 开启失败 + modal.SetText("开启自动解密失败: " + err.Error()) + } else { + // 开启成功 + if a.ctx.Version == 3 { + modal.SetText("已开启自动解密\n3.x版本数据文件更新不及时,有低延迟需求请使用4.0版本") + } else { + modal.SetText("已开启自动解密") + } + } + + // 更改菜单项名称 + a.updateMenuItemsState() + + // 添加确认按钮 + modal.AddButtons([]string{"OK"}) + modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) { + a.mainPages.RemovePage("modal") + }) + a.SetFocus(modal) + }) + }() + } else { + // 自动解密已开启,停止自动解密 + modal.SetText("正在停止自动解密...") + a.mainPages.AddPage("modal", modal, true, true) + a.SetFocus(modal) + + // 在后台停止自动解密 + go func() { + err := a.m.StopAutoDecrypt() + + // 在主线程中更新UI + a.QueueUpdateDraw(func() { + if err != nil { + // 停止失败 + modal.SetText("停止自动解密失败: " + err.Error()) + } else { + // 停止成功 + modal.SetText("已停止自动解密") + } + + // 更改菜单项名称 + a.updateMenuItemsState() + // 添加确认按钮 modal.AddButtons([]string{"OK"}) modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) { @@ -306,19 +428,28 @@ func (a *App) initMenu() { } setting := &menu.Item{ - Index: 5, + Index: 6, Name: "设置", Description: "设置应用程序选项", Selected: a.settingSelected, } - a.menu.AddItem(setting) + selectAccount := &menu.Item{ + Index: 7, + Name: "切换账号", + Description: "切换当前操作的账号,可以选择进程或历史账号", + Selected: a.selectAccountSelected, + } + a.menu.AddItem(getDataKey) a.menu.AddItem(decryptData) a.menu.AddItem(httpServer) + a.menu.AddItem(autoDecrypt) + a.menu.AddItem(setting) + a.menu.AddItem(selectAccount) a.menu.AddItem(&menu.Item{ - Index: 6, + Index: 8, Name: "退出", Description: "退出程序", Selected: func(i *menu.Item) { @@ -347,6 +478,16 @@ func (a *App) settingSelected(i *menu.Item) { description: "配置数据解密后的存储目录", action: a.settingWorkDir, }, + { + name: "设置数据密钥", + description: "配置数据解密密钥", + action: a.settingDataKey, + }, + { + name: "设置数据目录", + description: "配置微信数据文件所在目录", + action: a.settingDataDir, + }, } subMenu := menu.NewSubMenu("设置") @@ -370,43 +511,279 @@ func (a *App) settingSelected(i *menu.Item) { // settingHTTPPort 设置 HTTP 端口 func (a *App) settingHTTPPort() { - // 实现端口设置逻辑 - // 这里可以使用 tview.InputField 让用户输入端口 - form := tview.NewForm(). - AddInputField("地址", a.ctx.HTTPAddr, 20, nil, func(text string) { - a.m.SetHTTPAddr(text) - }). - AddButton("保存", func() { - a.mainPages.RemovePage("submenu2") - a.showInfo("HTTP 地址已设置为 " + a.ctx.HTTPAddr) - }). - AddButton("取消", func() { - a.mainPages.RemovePage("submenu2") - }) - form.SetBorder(true).SetTitle("设置 HTTP 地址") + // 使用我们的自定义表单组件 + formView := form.NewForm("设置 HTTP 地址") - a.mainPages.AddPage("submenu2", form, true, true) - a.SetFocus(form) + // 临时存储用户输入的值 + tempHTTPAddr := a.ctx.HTTPAddr + + // 添加输入字段 - 不再直接设置HTTP地址,而是更新临时变量 + formView.AddInputField("地址", tempHTTPAddr, 0, nil, func(text string) { + tempHTTPAddr = text // 只更新临时变量 + }) + + // 添加按钮 - 点击保存时才设置HTTP地址 + formView.AddButton("保存", func() { + a.m.SetHTTPAddr(tempHTTPAddr) // 在这里设置HTTP地址 + a.mainPages.RemovePage("submenu2") + a.showInfo("HTTP 地址已设置为 " + a.ctx.HTTPAddr) + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) } // settingWorkDir 设置工作目录 func (a *App) settingWorkDir() { - // 实现工作目录设置逻辑 - form := tview.NewForm(). - AddInputField("工作目录", a.ctx.WorkDir, 40, nil, func(text string) { - a.ctx.SetWorkDir(text) - }). - AddButton("保存", func() { - a.mainPages.RemovePage("submenu2") - a.showInfo("工作目录已设置为 " + a.ctx.WorkDir) - }). - AddButton("取消", func() { - a.mainPages.RemovePage("submenu2") - }) - form.SetBorder(true).SetTitle("设置工作目录") + // 使用我们的自定义表单组件 + formView := form.NewForm("设置工作目录") - a.mainPages.AddPage("submenu2", form, true, true) - a.SetFocus(form) + // 临时存储用户输入的值 + tempWorkDir := a.ctx.WorkDir + + // 添加输入字段 - 不再直接设置工作目录,而是更新临时变量 + formView.AddInputField("工作目录", tempWorkDir, 0, nil, func(text string) { + tempWorkDir = text // 只更新临时变量 + }) + + // 添加按钮 - 点击保存时才设置工作目录 + formView.AddButton("保存", func() { + a.ctx.SetWorkDir(tempWorkDir) // 在这里设置工作目录 + a.mainPages.RemovePage("submenu2") + a.showInfo("工作目录已设置为 " + a.ctx.WorkDir) + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) +} + +// settingDataKey 设置数据密钥 +func (a *App) settingDataKey() { + // 使用我们的自定义表单组件 + formView := form.NewForm("设置数据密钥") + + // 临时存储用户输入的值 + tempDataKey := a.ctx.DataKey + + // 添加输入字段 - 不直接设置数据密钥,而是更新临时变量 + formView.AddInputField("数据密钥", tempDataKey, 0, nil, func(text string) { + tempDataKey = text // 只更新临时变量 + }) + + // 添加按钮 - 点击保存时才设置数据密钥 + formView.AddButton("保存", func() { + a.ctx.DataKey = tempDataKey // 设置数据密钥 + a.mainPages.RemovePage("submenu2") + a.showInfo("数据密钥已设置") + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) +} + +// settingDataDir 设置数据目录 +func (a *App) settingDataDir() { + // 使用我们的自定义表单组件 + formView := form.NewForm("设置数据目录") + + // 临时存储用户输入的值 + tempDataDir := a.ctx.DataDir + + // 添加输入字段 - 不直接设置数据目录,而是更新临时变量 + formView.AddInputField("数据目录", tempDataDir, 0, nil, func(text string) { + tempDataDir = text // 只更新临时变量 + }) + + // 添加按钮 - 点击保存时才设置数据目录 + formView.AddButton("保存", func() { + a.ctx.DataDir = tempDataDir // 设置数据目录 + a.mainPages.RemovePage("submenu2") + a.showInfo("数据目录已设置为 " + a.ctx.DataDir) + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) +} + +// selectAccountSelected 处理切换账号菜单项的选择事件 +func (a *App) selectAccountSelected(i *menu.Item) { + // 创建子菜单 + subMenu := menu.NewSubMenu("切换账号") + + // 添加微信进程 + instances := a.m.wechat.GetWeChatInstances() + if len(instances) > 0 { + // 添加实例标题 + subMenu.AddItem(&menu.Item{ + Index: 0, + Name: "--- 微信进程 ---", + Description: "", + Hidden: false, + Selected: nil, + }) + + // 添加实例列表 + for idx, instance := range instances { + // 创建一个实例描述 + description := fmt.Sprintf("版本: %s 目录: %s", instance.FullVersion, instance.DataDir) + + // 标记当前选中的实例 + name := fmt.Sprintf("%s [%d]", instance.Name, instance.PID) + if a.ctx.Current != nil && a.ctx.Current.PID == instance.PID { + name = name + " [当前]" + } + + // 创建菜单项 + instanceItem := &menu.Item{ + Index: idx + 1, + Name: name, + Description: description, + Hidden: false, + Selected: func(instance *wechat.Account) func(*menu.Item) { + return func(*menu.Item) { + // 如果是当前账号,则无需切换 + if a.ctx.Current != nil && a.ctx.Current.PID == instance.PID { + a.mainPages.RemovePage("submenu") + a.showInfo("已经是当前账号") + return + } + + // 显示切换中的模态框 + modal := tview.NewModal().SetText("正在切换账号...") + a.mainPages.AddPage("modal", modal, true, true) + a.SetFocus(modal) + + // 在后台执行切换操作 + go func() { + err := a.m.Switch(instance, "") + + // 在主线程中更新UI + a.QueueUpdateDraw(func() { + a.mainPages.RemovePage("modal") + a.mainPages.RemovePage("submenu") + + if err != nil { + // 切换失败 + a.showError(fmt.Errorf("切换账号失败: %v", err)) + } else { + // 切换成功 + a.showInfo("切换账号成功") + // 更新菜单状态 + a.updateMenuItemsState() + } + }) + }() + } + }(instance), + } + subMenu.AddItem(instanceItem) + } + } + + // 添加历史账号 + if len(a.ctx.History) > 0 { + // 添加历史账号标题 + subMenu.AddItem(&menu.Item{ + Index: 100, + Name: "--- 历史账号 ---", + Description: "", + Hidden: false, + Selected: nil, + }) + + // 添加历史账号列表 + idx := 101 + for account, hist := range a.ctx.History { + // 创建一个账号描述 + description := fmt.Sprintf("版本: %s 目录: %s", hist.FullVersion, hist.DataDir) + + // 标记当前选中的账号 + name := account + if name == "" { + name = filepath.Base(hist.DataDir) + } + if a.ctx.DataDir == hist.DataDir { + name = name + " [当前]" + } + + // 创建菜单项 + histItem := &menu.Item{ + Index: idx, + Name: name, + Description: description, + Hidden: false, + Selected: func(account string) func(*menu.Item) { + return func(*menu.Item) { + // 如果是当前账号,则无需切换 + if a.ctx.Current != nil && a.ctx.DataDir == a.ctx.History[account].DataDir { + a.mainPages.RemovePage("submenu") + a.showInfo("已经是当前账号") + return + } + + // 显示切换中的模态框 + modal := tview.NewModal().SetText("正在切换账号...") + a.mainPages.AddPage("modal", modal, true, true) + a.SetFocus(modal) + + // 在后台执行切换操作 + go func() { + err := a.m.Switch(nil, account) + + // 在主线程中更新UI + a.QueueUpdateDraw(func() { + a.mainPages.RemovePage("modal") + a.mainPages.RemovePage("submenu") + + if err != nil { + // 切换失败 + a.showError(fmt.Errorf("切换账号失败: %v", err)) + } else { + // 切换成功 + a.showInfo("切换账号成功") + // 更新菜单状态 + a.updateMenuItemsState() + } + }) + }() + } + }(account), + } + idx++ + subMenu.AddItem(histItem) + } + } + + // 如果没有账号可选择 + if len(a.ctx.History) == 0 && len(instances) == 0 { + subMenu.AddItem(&menu.Item{ + Index: 1, + Name: "无可用账号", + Description: "未检测到微信进程或历史账号", + Hidden: false, + Selected: nil, + }) + } + + // 显示子菜单 + a.mainPages.AddPage("submenu", subMenu, true, true) + a.SetFocus(subMenu) } // showModal 显示一个模态对话框 diff --git a/internal/chatlog/ctx/context.go b/internal/chatlog/ctx/context.go index a97b833..fc7ee2d 100644 --- a/internal/chatlog/ctx/context.go +++ b/internal/chatlog/ctx/context.go @@ -2,6 +2,7 @@ package ctx import ( "sync" + "time" "github.com/sjzar/chatlog/internal/chatlog/conf" "github.com/sjzar/chatlog/internal/wechat" @@ -33,6 +34,10 @@ type Context struct { HTTPEnabled bool HTTPAddr string + // 自动解密 + AutoDecrypt bool + LastSession time.Time + // 当前选中的微信实例 Current *wechat.Account PID int @@ -63,6 +68,10 @@ func (c *Context) loadConfig() { func (c *Context) SwitchHistory(account string) { c.mu.Lock() defer c.mu.Unlock() + c.Current = nil + c.PID = 0 + c.ExePath = "" + c.Status = "" history, ok := c.History[account] if ok { c.Account = history.Account @@ -153,6 +162,13 @@ func (c *Context) SetDataDir(dir string) { c.Refresh() } +func (c *Context) SetAutoDecrypt(enabled bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.AutoDecrypt = enabled + c.UpdateConfig() +} + // 更新配置 func (c *Context) UpdateConfig() { pconf := conf.ProcessConfig{ diff --git a/internal/chatlog/http/service.go b/internal/chatlog/http/service.go index 08d04c1..8be622a 100644 --- a/internal/chatlog/http/service.go +++ b/internal/chatlog/http/service.go @@ -84,11 +84,12 @@ func (s *Service) Stop() error { } // 使用超时上下文优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := s.server.Shutdown(ctx); err != nil { - return errors.HTTPShutDown(err) + log.Debug().Err(err).Msg("Failed to shutdown HTTP server") + return nil } log.Info().Msg("HTTP server stopped") diff --git a/internal/chatlog/manager.go b/internal/chatlog/manager.go index 838f72f..d0b7766 100644 --- a/internal/chatlog/manager.go +++ b/internal/chatlog/manager.go @@ -6,12 +6,14 @@ import ( "path/filepath" "strings" + "github.com/rs/zerolog/log" "github.com/sjzar/chatlog/internal/chatlog/conf" "github.com/sjzar/chatlog/internal/chatlog/ctx" "github.com/sjzar/chatlog/internal/chatlog/database" "github.com/sjzar/chatlog/internal/chatlog/http" "github.com/sjzar/chatlog/internal/chatlog/mcp" "github.com/sjzar/chatlog/internal/chatlog/wechat" + iwechat "github.com/sjzar/chatlog/internal/wechat" "github.com/sjzar/chatlog/pkg/util" "github.com/sjzar/chatlog/pkg/util/dat2img" ) @@ -79,6 +81,33 @@ func (m *Manager) Run() error { return nil } +func (m *Manager) Switch(info *iwechat.Account, history string) error { + if m.ctx.AutoDecrypt { + if err := m.StopAutoDecrypt(); err != nil { + return err + } + } + if m.ctx.HTTPEnabled { + if err := m.stopService(); err != nil { + return err + } + } + if info != nil { + m.ctx.SwitchCurrent(info) + } else { + m.ctx.SwitchHistory(history) + } + + if m.ctx.HTTPEnabled { + // 启动HTTP服务 + if err := m.StartService(); err != nil { + log.Info().Err(err).Msg("启动服务失败") + m.StopService() + } + } + return nil +} + func (m *Manager) StartService() error { // 按依赖顺序启动服务 @@ -109,6 +138,17 @@ func (m *Manager) StartService() error { } func (m *Manager) StopService() error { + if err := m.stopService(); err != nil { + return err + } + + // 更新状态 + m.ctx.SetHTTPEnabled(false) + + return nil +} + +func (m *Manager) stopService() error { // 按依赖的反序停止服务 var errs []error @@ -124,9 +164,6 @@ func (m *Manager) StopService() error { errs = append(errs, err) } - // 更新状态 - m.ctx.SetHTTPEnabled(false) - // 如果有错误,返回第一个错误 if len(errs) > 0 { return errs[0] @@ -138,7 +175,7 @@ func (m *Manager) StopService() error { func (m *Manager) SetHTTPAddr(text string) error { var addr string if util.IsNumeric(text) { - addr = fmt.Sprintf("0.0.0.0:%s", text) + addr = fmt.Sprintf("127.0.0.1:%s", text) } else if strings.HasPrefix(text, "http://") { addr = strings.TrimPrefix(text, "http://") } else if strings.HasPrefix(text, "https://") { @@ -175,7 +212,7 @@ func (m *Manager) DecryptDBFiles() error { m.ctx.WorkDir = util.DefaultWorkDir(m.ctx.Account) } - if err := m.wechat.DecryptDBFiles(m.ctx.DataDir, m.ctx.WorkDir, m.ctx.DataKey, m.ctx.Platform, m.ctx.Version); err != nil { + if err := m.wechat.DecryptDBFiles(); err != nil { return err } m.ctx.Refresh() @@ -183,6 +220,48 @@ func (m *Manager) DecryptDBFiles() error { return nil } +func (m *Manager) StartAutoDecrypt() error { + if m.ctx.DataKey == "" || m.ctx.DataDir == "" { + return fmt.Errorf("请先获取密钥") + } + if m.ctx.WorkDir == "" { + return fmt.Errorf("请先执行解密数据") + } + + if err := m.wechat.StartAutoDecrypt(); err != nil { + return err + } + + m.ctx.SetAutoDecrypt(true) + return nil +} + +func (m *Manager) StopAutoDecrypt() error { + if err := m.wechat.StopAutoDecrypt(); err != nil { + return err + } + + m.ctx.SetAutoDecrypt(false) + return nil +} + +func (m *Manager) RefreshSession() error { + if m.db.GetDB() == nil { + if err := m.db.Start(); err != nil { + return err + } + } + resp, err := m.db.GetSessions("", 1, 0) + if err != nil { + return err + } + if len(resp.Items) == 0 { + return nil + } + m.ctx.LastSession = resp.Items[0].NTime + return nil +} + func (m *Manager) CommandKey(pid int) (string, error) { instances := m.wechat.GetWeChatInstances() if len(instances) == 0 { @@ -216,8 +295,12 @@ func (m *Manager) CommandDecrypt(dataDir string, workDir string, key string, pla if workDir == "" { workDir = util.DefaultWorkDir(filepath.Base(filepath.Dir(dataDir))) } - - if err := m.wechat.DecryptDBFiles(dataDir, workDir, key, platform, version); err != nil { + m.ctx.DataDir = dataDir + m.ctx.WorkDir = workDir + m.ctx.DataKey = key + m.ctx.Platform = platform + m.ctx.Version = version + if err := m.wechat.DecryptDBFiles(); err != nil { return err } diff --git a/internal/chatlog/wechat/service.go b/internal/chatlog/wechat/service.go index cdba4b8..8d7cf0e 100644 --- a/internal/chatlog/wechat/service.go +++ b/internal/chatlog/wechat/service.go @@ -5,24 +5,38 @@ import ( "fmt" "os" "path/filepath" - "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" "github.com/sjzar/chatlog/internal/chatlog/ctx" "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat" "github.com/sjzar/chatlog/internal/wechat/decrypt" + "github.com/sjzar/chatlog/pkg/filemonitor" "github.com/sjzar/chatlog/pkg/util" +) - "github.com/rs/zerolog/log" +var ( + DebounceTime = 1 * time.Second + MaxWaitTime = 10 * time.Second ) type Service struct { - ctx *ctx.Context + ctx *ctx.Context + lastEvents map[string]time.Time + pendingActions map[string]bool + mutex sync.Mutex + fm *filemonitor.FileMonitor } func NewService(ctx *ctx.Context) *Service { return &Service{ - ctx: ctx, + ctx: ctx, + lastEvents: make(map[string]time.Time), + pendingActions: make(map[string]bool), } } @@ -46,90 +60,128 @@ func (s *Service) GetDataKey(info *wechat.Account) (string, error) { return key, nil } -// FindDBFiles finds all .db files in the specified directory -func (s *Service) FindDBFiles(rootDir string, recursive bool) ([]string, error) { - // Check if directory exists - info, err := os.Stat(rootDir) +func (s *Service) StartAutoDecrypt() error { + dbGroup, err := filemonitor.NewFileGroup("wechat", s.ctx.DataDir, `.*\.db$`, []string{"fts"}) if err != nil { - return nil, fmt.Errorf("cannot access directory %s: %w", rootDir, err) + return err } + dbGroup.AddCallback(s.DecryptFileCallback) - if !info.IsDir() { - return nil, fmt.Errorf("%s is not a directory", rootDir) + s.fm = filemonitor.NewFileMonitor() + s.fm.AddGroup(dbGroup) + if err := s.fm.Start(); err != nil { + log.Debug().Err(err).Msg("failed to start file monitor") + return err } + return nil +} - var dbFiles []string - - // Define walk function - walkFunc := func(path string, info os.FileInfo, err error) error { - if err != nil { - // If a file or directory can't be accessed, log the error but continue - log.Err(err).Msgf("Warning: Cannot access %s", path) - return nil - } - - // If it's a directory and not the root directory, and we're not recursively searching, skip it - if info.IsDir() && path != rootDir && !recursive { - return filepath.SkipDir - } - - // Check if file extension is .db - if !info.IsDir() && strings.ToLower(filepath.Ext(path)) == ".db" { - dbFiles = append(dbFiles, path) +func (s *Service) StopAutoDecrypt() error { + if s.fm != nil { + if err := s.fm.Stop(); err != nil { + return err } + } + s.fm = nil + return nil +} +func (s *Service) DecryptFileCallback(event fsnotify.Event) error { + if event.Op.Has(fsnotify.Chmod) || !event.Op.Has(fsnotify.Write) { return nil } - // Start traversal - err = filepath.Walk(rootDir, walkFunc) - if err != nil { - return nil, fmt.Errorf("error traversing directory: %w", err) + s.mutex.Lock() + s.lastEvents[event.Name] = time.Now() + + if !s.pendingActions[event.Name] { + s.pendingActions[event.Name] = true + s.mutex.Unlock() + go s.waitAndProcess(event.Name) + } else { + s.mutex.Unlock() } - if len(dbFiles) == 0 { - return nil, fmt.Errorf("no .db files found") - } - - return dbFiles, nil + return nil } -func (s *Service) DecryptDBFiles(dataDir string, workDir string, key string, platform string, version int) error { +func (s *Service) waitAndProcess(dbFile string) { + start := time.Now() + for { + time.Sleep(DebounceTime) - ctx := context.Background() + s.mutex.Lock() + lastEventTime := s.lastEvents[dbFile] + elapsed := time.Since(lastEventTime) + totalElapsed := time.Since(start) - dbfiles, err := s.FindDBFiles(dataDir, true) + if elapsed >= DebounceTime || totalElapsed >= MaxWaitTime { + s.pendingActions[dbFile] = false + s.mutex.Unlock() + + log.Debug().Msgf("Processing file: %s", dbFile) + s.DecryptDBFile(dbFile) + return + } + s.mutex.Unlock() + } +} + +func (s *Service) DecryptDBFile(dbFile string) error { + + decryptor, err := decrypt.NewDecryptor(s.ctx.Platform, s.ctx.Version) if err != nil { return err } - decryptor, err := decrypt.NewDecryptor(platform, version) - if err != nil { + output := filepath.Join(s.ctx.WorkDir, dbFile[len(s.ctx.DataDir):]) + if err := util.PrepareDir(filepath.Dir(output)); err != nil { return err } - for _, dbfile := range dbfiles { - output := filepath.Join(workDir, dbfile[len(dataDir):]) - if err := util.PrepareDir(filepath.Dir(output)); err != nil { - return err + outputTemp := output + ".tmp" + outputFile, err := os.Create(outputTemp) + if err != nil { + return fmt.Errorf("failed to create output file: %v", err) + } + defer func() { + outputFile.Close() + if err := os.Rename(outputTemp, output); err != nil { + log.Debug().Err(err).Msgf("failed to rename %s to %s", outputTemp, output) } + }() - outputFile, err := os.Create(output) - if err != nil { - return fmt.Errorf("failed to create output file: %v", err) - } - defer outputFile.Close() - - if err := decryptor.Decrypt(ctx, dbfile, key, outputFile); err != nil { - log.Err(err).Msgf("failed to decrypt %s", dbfile) - if err == errors.ErrAlreadyDecrypted { - if data, err := os.ReadFile(dbfile); err == nil { - outputFile.Write(data) - } - continue + if err := decryptor.Decrypt(context.Background(), dbFile, s.ctx.DataKey, outputFile); err != nil { + if err == errors.ErrAlreadyDecrypted { + if data, err := os.ReadFile(dbFile); err == nil { + outputFile.Write(data) } + return nil + } + log.Err(err).Msgf("failed to decrypt %s", dbFile) + return err + } + + log.Debug().Msgf("Decrypted %s to %s", dbFile, output) + + return nil +} + +func (s *Service) DecryptDBFiles() error { + dbGroup, err := filemonitor.NewFileGroup("wechat", s.ctx.DataDir, `.*\.db$`, []string{"fts"}) + if err != nil { + return err + } + + dbFiles, err := dbGroup.List() + if err != nil { + return err + } + + for _, dbFile := range dbFiles { + if err := s.DecryptDBFile(dbFile); err != nil { + log.Debug().Msgf("DecryptDBFile %s failed: %v", dbFile, err) continue - // return err } } diff --git a/internal/errors/wechatdb_errors.go b/internal/errors/wechatdb_errors.go index 42b6745..d6a2e46 100644 --- a/internal/errors/wechatdb_errors.go +++ b/internal/errors/wechatdb_errors.go @@ -60,3 +60,7 @@ func ContactNotFound(key string) *Error { func InitCacheFailed(cause error) *Error { return New(cause, http.StatusInternalServerError, "init cache failed").WithStack() } + +func FileGroupNotFound(name string) *Error { + return Newf(nil, http.StatusNotFound, "file group not found: %s", name).WithStack() +} diff --git a/internal/ui/form/form.go b/internal/ui/form/form.go new file mode 100644 index 0000000..6f94d69 --- /dev/null +++ b/internal/ui/form/form.go @@ -0,0 +1,258 @@ +package form + +import ( + "fmt" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + "github.com/sjzar/chatlog/internal/ui/style" +) + +const ( + // DialogPadding dialog inner padding. + DialogPadding = 3 + + // DialogHelpHeight dialog help text height. + DialogHelpHeight = 1 + + // DialogMinWidth dialog min width. + DialogMinWidth = 40 + + // FormHeightOffset form height offset for border. + FormHeightOffset = 3 + + // 额外的宽度补偿,类似于 submenu 的 cmdWidthOffset + formWidthOffset = 10 +) + +// Form is a modal form component with a title, form fields, and help text. +type Form struct { + *tview.Box + title string + layout *tview.Flex + form *tview.Form + helpText *tview.TextView + width int + height int + cancelHandler func() + fields []formField // 存储字段信息以便重新计算宽度 +} + +// formField 存储表单字段的信息 +type formField struct { + label string + value string + fieldWidth int +} + +// NewForm creates a new form with the given title. +func NewForm(title string) *Form { + f := &Form{ + Box: tview.NewBox(), + title: title, + layout: tview.NewFlex().SetDirection(tview.FlexRow), + form: tview.NewForm(), + fields: make([]formField, 0), + } + + // 设置表单样式 + f.form.SetBorderPadding(1, 1, 1, 1) + f.form.SetBackgroundColor(style.DialogBgColor) + f.form.SetFieldBackgroundColor(style.BgColor) + f.form.SetFieldTextColor(style.FgColor) + f.form.SetButtonBackgroundColor(style.ButtonBgColor) + f.form.SetButtonTextColor(style.FgColor) + f.form.SetLabelColor(style.DialogFgColor) + f.form.SetButtonsAlign(tview.AlignCenter) + + // 创建帮助文本 + f.helpText = tview.NewTextView() + f.helpText.SetDynamicColors(true) + f.helpText.SetTextAlign(tview.AlignCenter) + f.helpText.SetTextColor(style.DialogFgColor) + f.helpText.SetBackgroundColor(style.DialogBgColor) + fmt.Fprintf(f.helpText, + "[%s::b]Tab[%s::b]: 导航 [%s::b]Enter[%s::b]: 选择 [%s::b]ESC[%s::b]: 返回", + style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor), + style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor), + style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor), + ) + + // 创建布局 + formLayout := tview.NewFlex().SetDirection(tview.FlexColumn) + formLayout.AddItem(EmptyBoxSpace(style.DialogBgColor), 1, 0, false) + formLayout.AddItem(f.form, 0, 1, true) + formLayout.AddItem(EmptyBoxSpace(style.DialogBgColor), 1, 0, false) + + // 设置主布局 + f.layout.SetTitle(fmt.Sprintf("[::b]%s", f.title)) + f.layout.SetTitleColor(style.DialogFgColor) + f.layout.SetTitleAlign(tview.AlignCenter) + f.layout.SetBorder(true) + f.layout.SetBorderColor(style.DialogBorderColor) + f.layout.SetBackgroundColor(style.DialogBgColor) + + // 添加表单区域 + f.layout.AddItem(formLayout, 0, 1, true) + + // 添加帮助文本区域 + f.layout.AddItem(f.helpText, DialogHelpHeight, 0, false) + + return f +} + +// AddInputField adds an input field to the form. +func (f *Form) AddInputField(label, value string, fieldWidth int, accept func(textToCheck string, lastChar rune) bool, changed func(text string)) *Form { + // 存储字段信息 + f.fields = append(f.fields, formField{ + label: label, + value: value, + fieldWidth: fieldWidth, + }) + + // 添加输入字段到表单 + f.form.AddInputField(label, value, fieldWidth, accept, changed) + + // 更新表单尺寸 + f.recalculateSize() + + return f +} + +// AddButton adds a button to the form. +func (f *Form) AddButton(label string, selected func()) *Form { + f.form.AddButton(label, selected) + // 更新表单尺寸 + f.recalculateSize() + return f +} + +// AddCheckbox adds a checkbox to the form. +func (f *Form) AddCheckbox(label string, checked bool, changed func(checked bool)) *Form { + f.form.AddCheckbox(label, checked, changed) + // 更新表单尺寸 + f.recalculateSize() + return f +} + +// SetCancelFunc sets the function to be called when the form is cancelled. +func (f *Form) SetCancelFunc(handler func()) *Form { + f.cancelHandler = handler + return f +} + +// recalculateSize 重新计算表单尺寸 +func (f *Form) recalculateSize() { + // 计算表单项数量 + itemCount := f.form.GetFormItemCount() + + // 计算高度 - 每个表单项占2行,按钮区域至少占2行,再加上边框和帮助文本 + f.height = (itemCount * 2) + 2 + FormHeightOffset + DialogHelpHeight + + // 计算宽度 - 类似于 submenu 的实现 + maxLabelWidth := 0 + maxValueWidth := 0 + + // 遍历所有字段,找出最长的标签和值 + for _, field := range f.fields { + if len(field.label) > maxLabelWidth { + maxLabelWidth = len(field.label) + } + + // 对于值,使用字段宽度和实际值长度中的较大者 + valueWidth := field.fieldWidth + if len(field.value) > valueWidth { + valueWidth = len(field.value) + } + + if valueWidth > maxValueWidth { + maxValueWidth = valueWidth + } + } + + // 计算总宽度,类似于 submenu 的计算方式 + f.width = maxLabelWidth + maxValueWidth + formWidthOffset + + // 确保宽度不小于最小值 + if f.width < DialogMinWidth { + f.width = DialogMinWidth + } +} + +// Draw draws the form on the screen. +func (f *Form) Draw(screen tcell.Screen) { + // 在绘制前重新计算尺寸,确保尺寸是最新的 + f.recalculateSize() + + // 绘制 + f.Box.DrawForSubclass(screen, f) + f.layout.Draw(screen) +} + +// SetRect sets the position and size of the form. +func (f *Form) SetRect(x, y, width, height int) { + // 确保尺寸是最新的 + f.recalculateSize() + + // 类似于 submenu 的实现 + ws := (width - f.width) / 2 + hs := (height - f.height) / 2 + + // 确保不会超出屏幕 + if f.width > width { + ws = 0 + f.width = width - 1 + } + + if f.height > height { + hs = 0 + f.height = height - 1 + } + + // 设置表单位置 + f.Box.SetRect(x+ws, y+hs, f.width, f.height) + + // 获取内部矩形并设置布局 + x, y, width, height = f.Box.GetInnerRect() + f.layout.SetRect(x, y, width, height) +} + +// Focus is called when this primitive receives focus. +func (f *Form) Focus(delegate func(p tview.Primitive)) { + // 确保表单获得焦点 + if f.form != nil { + delegate(f.form) + } else { + // 如果表单为空,则让Box获得焦点 + delegate(f.Box) + } +} + +// HasFocus returns whether or not this primitive has focus. +func (f *Form) HasFocus() bool { + return f.form.HasFocus() +} + +// InputHandler returns the handler for this primitive. +func (f *Form) InputHandler() func(event *tcell.EventKey, setFocus func(p tview.Primitive)) { + return f.WrapInputHandler(func(event *tcell.EventKey, setFocus func(p tview.Primitive)) { + // ESC键处理 + if event.Key() == tcell.KeyEscape && f.cancelHandler != nil { + f.cancelHandler() + return + } + + // 将事件传递给表单 + if handler := f.form.InputHandler(); handler != nil { + handler(event, setFocus) + } + }) +} + +// EmptyBoxSpace creates an empty box with the specified background color. +func EmptyBoxSpace(bgColor tcell.Color) *tview.Box { + box := tview.NewBox() + box.SetBackgroundColor(bgColor) + box.SetBorder(false) + return box +} diff --git a/internal/ui/infobar/infobar.go b/internal/ui/infobar/infobar.go index 2395f76..05aa648 100644 --- a/internal/ui/infobar/infobar.go +++ b/internal/ui/infobar/infobar.go @@ -15,13 +15,14 @@ const ( // InfoBarViewHeight info bar height. const ( - InfoBarViewHeight = 6 + InfoBarViewHeight = 7 accountRow = 0 - pidRow = 1 - statusRow = 2 - dataUsageRow = 3 - workUsageRow = 4 - httpServerRow = 5 + statusRow = 1 + platformRow = 2 + sessionRow = 3 + dataUsageRow = 4 + workUsageRow = 5 + httpServerRow = 6 // 列索引 labelCol1 = 0 // 第一列标签 @@ -43,7 +44,7 @@ func New() *InfoBar { table := tview.NewTable() headerColor := style.InfoBarItemFgColor - // Account 和 Version 行 + // Account 和 PID 行 table.SetCell( accountRow, labelCol1, @@ -54,26 +55,11 @@ func New() *InfoBar { table.SetCell( accountRow, labelCol2, - tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Version:")), + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "PID:")), ) table.SetCell(accountRow, valueCol2, tview.NewTableCell("")) - // PID 和 ExePath 行 - table.SetCell( - pidRow, - labelCol1, - tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "PID:")), - ) - table.SetCell(pidRow, valueCol1, tview.NewTableCell("")) - - table.SetCell( - pidRow, - labelCol2, - tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "ExePath:")), - ) - table.SetCell(pidRow, valueCol2, tview.NewTableCell("")) - - // Status 和 Key 行 + // Status 和 ExePath 行 table.SetCell( statusRow, labelCol1, @@ -84,10 +70,40 @@ func New() *InfoBar { table.SetCell( statusRow, labelCol2, - tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Data Key:")), + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "ExePath:")), ) table.SetCell(statusRow, valueCol2, tview.NewTableCell("")) + // Platform 和 Version 行 + table.SetCell( + platformRow, + labelCol1, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Platform:")), + ) + table.SetCell(platformRow, valueCol1, tview.NewTableCell("")) + + table.SetCell( + platformRow, + labelCol2, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Version:")), + ) + table.SetCell(platformRow, valueCol2, tview.NewTableCell("")) + + // Session 和 Data Key 行 + table.SetCell( + sessionRow, + labelCol1, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Session:")), + ) + table.SetCell(sessionRow, valueCol1, tview.NewTableCell("")) + + table.SetCell( + sessionRow, + labelCol2, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Data Key:")), + ) + table.SetCell(sessionRow, valueCol2, tview.NewTableCell("")) + // Data Usage 和 Data Dir 行 table.SetCell( dataUsageRow, @@ -126,6 +142,13 @@ func New() *InfoBar { ) table.SetCell(httpServerRow, valueCol1, tview.NewTableCell("")) + table.SetCell( + httpServerRow, + labelCol2, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Auto Decrypt:")), + ) + table.SetCell(httpServerRow, valueCol2, tview.NewTableCell("")) + // infobar infoBar := &InfoBar{ Box: tview.NewBox(), @@ -141,17 +164,25 @@ func (info *InfoBar) UpdateAccount(account string) { } func (info *InfoBar) UpdateBasicInfo(pid int, version string, exePath string) { - info.table.GetCell(pidRow, valueCol1).SetText(fmt.Sprintf("%d", pid)) - info.table.GetCell(pidRow, valueCol2).SetText(exePath) - info.table.GetCell(accountRow, valueCol2).SetText(version) + info.table.GetCell(accountRow, valueCol2).SetText(fmt.Sprintf("%d", pid)) + info.table.GetCell(statusRow, valueCol2).SetText(exePath) + info.table.GetCell(platformRow, valueCol2).SetText(version) } func (info *InfoBar) UpdateStatus(status string) { info.table.GetCell(statusRow, valueCol1).SetText(status) } +func (info *InfoBar) UpdatePlatform(text string) { + info.table.GetCell(platformRow, valueCol1).SetText(text) +} + +func (info *InfoBar) UpdateSession(text string) { + info.table.GetCell(sessionRow, valueCol1).SetText(text) +} + func (info *InfoBar) UpdateDataKey(key string) { - info.table.GetCell(statusRow, valueCol2).SetText(key) + info.table.GetCell(sessionRow, valueCol2).SetText(key) } func (info *InfoBar) UpdateDataUsageDir(dataUsage string, dataDir string) { @@ -169,6 +200,11 @@ func (info *InfoBar) UpdateHTTPServer(server string) { info.table.GetCell(httpServerRow, valueCol1).SetText(server) } +// UpdateAutoDecrypt updates Auto Decrypt value. +func (info *InfoBar) UpdateAutoDecrypt(text string) { + info.table.GetCell(httpServerRow, valueCol2).SetText(text) +} + // Draw draws this primitive onto the screen. func (info *InfoBar) Draw(screen tcell.Screen) { info.Box.DrawForSubclass(screen, info) diff --git a/internal/wechatdb/datasource/darwinv3/datasource.go b/internal/wechatdb/datasource/darwinv3/datasource.go index e754e39..3171f35 100644 --- a/internal/wechatdb/datasource/darwinv3/datasource.go +++ b/internal/wechatdb/datasource/darwinv3/datasource.go @@ -3,87 +3,127 @@ package darwinv3 import ( "context" "crypto/md5" - "database/sql" "encoding/hex" "fmt" "strings" "time" + "github.com/fsnotify/fsnotify" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" - "github.com/sjzar/chatlog/pkg/util" + "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" ) const ( - MessageFilePattern = "^msg_([0-9]?[0-9])?\\.db$" - ContactFilePattern = "^wccontact_new2\\.db$" - ChatRoomFilePattern = "^group_new\\.db$" - SessionFilePattern = "^session_new\\.db$" - MediaFilePattern = "^hldata\\.db$" + Message = "message" + Contact = "contact" + ChatRoom = "chatroom" + Session = "session" + Media = "media" ) -type DataSource struct { - path string - messageDbs []*sql.DB - contactDb *sql.DB - chatRoomDb *sql.DB - sessionDb *sql.DB - mediaDb *sql.DB +var Groups = []dbm.Group{ + { + Name: Message, + Pattern: `^msg_([0-9]?[0-9])?\.db$`, + BlackList: []string{}, + }, + { + Name: Contact, + Pattern: `^wccontact_new2\.db$`, + BlackList: []string{}, + }, + { + Name: ChatRoom, + Pattern: `group_new\.db$`, + BlackList: []string{}, + }, + { + Name: Session, + Pattern: `^session_new\.db$`, + BlackList: []string{}, + }, + { + Name: Media, + Pattern: `^hldata\.db$`, + BlackList: []string{}, + }, +} - talkerDBMap map[string]*sql.DB +type DataSource struct { + path string + dbm *dbm.DBManager + + talkerDBMap map[string]string user2DisplayName map[string]string } func New(path string) (*DataSource, error) { ds := &DataSource{ path: path, - messageDbs: make([]*sql.DB, 0), - talkerDBMap: make(map[string]*sql.DB), + dbm: dbm.NewDBManager(path), + talkerDBMap: make(map[string]string), user2DisplayName: make(map[string]string), } - if err := ds.initMessageDbs(path); err != nil { + for _, g := range Groups { + ds.dbm.AddGroup(g) + } + + if err := ds.dbm.Start(); err != nil { + return nil, err + } + + if err := ds.initMessageDbs(); err != nil { return nil, errors.DBInitFailed(err) } - if err := ds.initContactDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initChatRoomDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initSessionDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initMediaDb(path); err != nil { + if err := ds.initChatRoomDb(); err != nil { return nil, errors.DBInitFailed(err) } + ds.dbm.AddCallback(Message, func(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := ds.initMessageDbs(); err != nil { + log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name) + } + return nil + }) + ds.dbm.AddCallback(ChatRoom, func(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := ds.initChatRoomDb(); err != nil { + log.Err(err).Msgf("Failed to reinitialize chatroom DB: %s", event.Name) + } + return nil + }) + return ds, nil } -func (ds *DataSource) initMessageDbs(path string) error { +func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error { + return ds.dbm.AddCallback(name, callback) +} - files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) +func (ds *DataSource) initMessageDbs() error { + + dbPaths, err := ds.dbm.GetDBPath(Message) if err != nil { - return errors.DBFileNotFound(path, MessageFilePattern, err) + return err } - - if len(files) == 0 { - return errors.DBFileNotFound(path, MessageFilePattern, nil) - } - // 处理每个数据库文件 - for _, filePath := range files { - // 连接数据库 - db, err := sql.Open("sqlite3", filePath) + talkerDBMap := make(map[string]string) + for _, filePath := range dbPaths { + db, err := ds.dbm.OpenDB(filePath) if err != nil { - log.Err(err).Msgf("连接数据库 %s 失败", filePath) + log.Err(err).Msgf("获取数据库 %s 失败", filePath) continue } - ds.messageDbs = append(ds.messageDbs, db) // 获取所有表名 rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Chat_%'") @@ -104,96 +144,42 @@ func (ds *DataSource) initMessageDbs(path string) error { if talkerMd5 == "" { continue } - ds.talkerDBMap[talkerMd5] = db + talkerDBMap[talkerMd5] = filePath } rows.Close() - } + ds.talkerDBMap = talkerDBMap return nil } -func (ds *DataSource) initContactDb(path string) error { - - files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) +func (ds *DataSource) initChatRoomDb() error { + db, err := ds.dbm.GetDB(ChatRoom) if err != nil { - return errors.DBFileNotFound(path, ContactFilePattern, err) + return err } - if len(files) == 0 { - return errors.DBFileNotFound(path, ContactFilePattern, nil) - } - - ds.contactDb, err = sql.Open("sqlite3", files[0]) + rows, err := db.Query("SELECT m_nsUsrName, IFNULL(nickname,\"\") FROM GroupMember") if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - return nil -} - -func (ds *DataSource) initChatRoomDb(path string) error { - files, err := util.FindFilesWithPatterns(path, ChatRoomFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, ChatRoomFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, ChatRoomFilePattern, nil) - } - ds.chatRoomDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - rows, err := ds.chatRoomDb.Query("SELECT m_nsUsrName, IFNULL(nickname,\"\") FROM GroupMember") - if err != nil { - log.Err(err).Msgf("数据库 %s 获取群聊成员失败", files[0]) + log.Err(err).Msg("获取群聊成员失败") return nil } + user2DisplayName := make(map[string]string) for rows.Next() { var user string var nickName string if err := rows.Scan(&user, &nickName); err != nil { - log.Err(err).Msgf("数据库 %s 扫描表名失败", files[0]) + log.Err(err).Msg("扫描表名失败") continue } - ds.user2DisplayName[user] = nickName + user2DisplayName[user] = nickName } rows.Close() + ds.user2DisplayName = user2DisplayName return nil } -func (ds *DataSource) initSessionDb(path string) error { - files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, SessionFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, SessionFilePattern, nil) - } - ds.sessionDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - return nil -} - -func (ds *DataSource) initMediaDb(path string) error { - files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, MediaFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, MediaFilePattern, nil) - } - ds.mediaDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - return nil -} - // GetMessages 实现获取消息的方法 func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { // 在 darwinv3 中,每个联系人/群聊的消息存储在单独的表中,表名为 Chat_md5(talker) @@ -204,10 +190,14 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T _talkerMd5Bytes := md5.Sum([]byte(talker)) talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) - db, ok := ds.talkerDBMap[talkerMd5] + dbPath, ok := ds.talkerDBMap[talkerMd5] if !ok { return nil, errors.TalkerNotFound(talker) } + db, err := ds.dbm.OpenDB(dbPath) + if err != nil { + return nil, err + } tableName := fmt.Sprintf("Chat_%s", talkerMd5) // 构建查询条件 @@ -297,7 +287,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -351,7 +345,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse } // 执行查询 - rows, err := ds.chatRoomDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(ChatRoom) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -380,7 +378,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts, err := ds.GetContacts(ctx, key, 1, 0) if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") { // 再次尝试通过用户名查找群聊 - rows, err := ds.chatRoomDb.QueryContext(ctx, + rows, err := db.QueryContext(ctx, `SELECT IFNULL(m_nsUsrName,""), IFNULL(nickname,""), IFNULL(m_nsRemark,""), IFNULL(m_nsChatRoomMemList,""), IFNULL(m_nsChatRoomAdminList,"") FROM GroupContact WHERE m_nsUsrName = ?`, @@ -448,7 +446,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.sessionDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Session) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -506,7 +508,11 @@ WHERE r.mediaMd5 = ?` args := []interface{}{key} // 执行查询 - rows, err := ds.mediaDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Media) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -541,46 +547,5 @@ WHERE // Close 实现关闭数据库连接的方法 func (ds *DataSource) Close() error { - var errs []error - - // 关闭消息数据库连接 - for _, db := range ds.messageDbs { - if err := db.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭联系人数据库连接 - if ds.contactDb != nil { - if err := ds.contactDb.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭群聊数据库连接 - if ds.chatRoomDb != nil { - if err := ds.chatRoomDb.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭会话数据库连接 - if ds.sessionDb != nil { - if err := ds.sessionDb.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭媒体数据库连接 - if ds.mediaDb != nil { - if err := ds.mediaDb.Close(); err != nil { - errs = append(errs, err) - } - } - - if len(errs) > 0 { - return errors.DBCloseFailed(errs[0]) - } - - return nil + return ds.dbm.Close() } diff --git a/internal/wechatdb/datasource/datasource.go b/internal/wechatdb/datasource/datasource.go index 327362a..56d68bd 100644 --- a/internal/wechatdb/datasource/datasource.go +++ b/internal/wechatdb/datasource/datasource.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/fsnotify/fsnotify" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource/darwinv3" @@ -28,6 +30,9 @@ type DataSource interface { // 媒体 GetMedia(ctx context.Context, _type string, key string) (*model.Media, error) + // 设置回调函数 + SetCallback(name string, callback func(event fsnotify.Event) error) error + Close() error } diff --git a/internal/wechatdb/datasource/dbm/dbm.go b/internal/wechatdb/datasource/dbm/dbm.go new file mode 100644 index 0000000..fc4efd9 --- /dev/null +++ b/internal/wechatdb/datasource/dbm/dbm.go @@ -0,0 +1,170 @@ +package dbm + +import ( + "database/sql" + "runtime" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog/log" + + "github.com/sjzar/chatlog/internal/errors" + "github.com/sjzar/chatlog/pkg/filecopy" + "github.com/sjzar/chatlog/pkg/filemonitor" +) + +type DBManager struct { + path string + fm *filemonitor.FileMonitor + fgs map[string]*filemonitor.FileGroup + dbs map[string]*sql.DB + dbPaths map[string][]string + mutex sync.RWMutex +} + +func NewDBManager(path string) *DBManager { + return &DBManager{ + path: path, + fm: filemonitor.NewFileMonitor(), + fgs: make(map[string]*filemonitor.FileGroup), + dbs: make(map[string]*sql.DB), + dbPaths: make(map[string][]string), + } +} + +func (d *DBManager) AddGroup(g Group) error { + fg, err := filemonitor.NewFileGroup(g.Name, d.path, g.Pattern, g.BlackList) + if err != nil { + return err + } + fg.AddCallback(d.Callback) + d.fm.AddGroup(fg) + d.mutex.Lock() + d.fgs[g.Name] = fg + d.mutex.Unlock() + return nil +} + +func (d *DBManager) AddCallback(name string, callback func(event fsnotify.Event) error) error { + d.mutex.RLock() + fg, ok := d.fgs[name] + d.mutex.RUnlock() + if !ok { + return errors.FileGroupNotFound(name) + } + fg.AddCallback(callback) + return nil +} + +func (d *DBManager) GetDB(name string) (*sql.DB, error) { + dbPaths, err := d.GetDBPath(name) + if err != nil { + return nil, err + } + return d.OpenDB(dbPaths[0]) +} + +func (d *DBManager) GetDBs(name string) ([]*sql.DB, error) { + dbPaths, err := d.GetDBPath(name) + if err != nil { + return nil, err + } + dbs := make([]*sql.DB, 0) + for _, file := range dbPaths { + db, err := d.OpenDB(file) + if err != nil { + return nil, err + } + dbs = append(dbs, db) + } + return dbs, nil +} + +func (d *DBManager) GetDBPath(name string) ([]string, error) { + d.mutex.RLock() + dbPaths, ok := d.dbPaths[name] + d.mutex.RUnlock() + if !ok { + d.mutex.RLock() + fg, ok := d.fgs[name] + d.mutex.RUnlock() + if !ok { + return nil, errors.FileGroupNotFound(name) + } + list, err := fg.List() + if err != nil { + return nil, errors.DBFileNotFound(d.path, fg.PatternStr, err) + } + if len(list) == 0 { + return nil, errors.DBFileNotFound(d.path, fg.PatternStr, nil) + } + dbPaths = list + d.mutex.Lock() + d.dbPaths[name] = dbPaths + d.mutex.Unlock() + } + return dbPaths, nil +} + +func (d *DBManager) OpenDB(path string) (*sql.DB, error) { + d.mutex.RLock() + db, ok := d.dbs[path] + d.mutex.RUnlock() + if ok { + return db, nil + } + var err error + tempPath := path + if runtime.GOOS == "windows" { + tempPath, err = filecopy.GetTempCopy(path) + if err != nil { + log.Err(err).Msgf("获取临时拷贝文件 %s 失败", path) + return nil, err + } + } + db, err = sql.Open("sqlite3", tempPath) + if err != nil { + log.Err(err).Msgf("连接数据库 %s 失败", path) + return nil, err + } + d.mutex.Lock() + d.dbs[path] = db + d.mutex.Unlock() + return db, nil +} + +func (d *DBManager) Callback(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + + d.mutex.Lock() + db, ok := d.dbs[event.Name] + if ok { + delete(d.dbs, event.Name) + go func(db *sql.DB) { + time.Sleep(time.Second * 5) + db.Close() + }(db) + } + d.mutex.Unlock() + + return nil +} + +func (d *DBManager) Start() error { + return d.fm.Start() +} + +func (d *DBManager) Stop() error { + return d.fm.Stop() +} + +func (d *DBManager) Close() error { + for _, db := range d.dbs { + db.Close() + } + return d.fm.Stop() +} diff --git a/internal/wechatdb/datasource/dbm/dbm_test.go b/internal/wechatdb/datasource/dbm/dbm_test.go new file mode 100644 index 0000000..25764a9 --- /dev/null +++ b/internal/wechatdb/datasource/dbm/dbm_test.go @@ -0,0 +1,42 @@ +package dbm + +import ( + "fmt" + "testing" + "time" +) + +func TestXxx(t *testing.T) { + path := "/Users/sarv/Documents/chatlog/bigjun_9e7a" + + g := Group{ + Name: "session", + Pattern: `session\.db$`, + BlackList: []string{}, + } + + d := NewDBManager(path) + d.AddGroup(g) + d.Start() + + i := 0 + for { + db, err := d.GetDB("session") + if err != nil { + fmt.Println(err) + break + } + + var username string + row := db.QueryRow(`SELECT username FROM SessionTable LIMIT 1`) + if err := row.Scan(&username); err != nil { + fmt.Printf("Error scanning row: %v\n", err) + time.Sleep(100 * time.Millisecond) + continue + } + fmt.Printf("%d: Username: %s\n", i, username) + i++ + time.Sleep(1000 * time.Millisecond) + } + +} diff --git a/internal/wechatdb/datasource/dbm/group.go b/internal/wechatdb/datasource/dbm/group.go new file mode 100644 index 0000000..ab9519b --- /dev/null +++ b/internal/wechatdb/datasource/dbm/group.go @@ -0,0 +1,7 @@ +package dbm + +type Group struct { + Name string + Pattern string + BlackList []string +} diff --git a/internal/wechatdb/datasource/v4/datasource.go b/internal/wechatdb/datasource/v4/datasource.go index 589cfdd..f330956 100644 --- a/internal/wechatdb/datasource/v4/datasource.go +++ b/internal/wechatdb/datasource/v4/datasource.go @@ -10,22 +10,51 @@ import ( "strings" "time" + "github.com/fsnotify/fsnotify" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" - "github.com/sjzar/chatlog/pkg/util" + "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" ) const ( - MessageFilePattern = "^message_([0-9]?[0-9])?\\.db$" - ContactFilePattern = "^contact\\.db$" - SessionFilePattern = "^session\\.db$" - MediaFilePattern = "^hardlink\\.db$" - VoiceFilePattern = "^media_([0-9]?[0-9])?\\.db$" + Message = "message" + Contact = "contact" + Session = "session" + Media = "media" + Voice = "voice" ) +var Groups = []dbm.Group{ + { + Name: Message, + Pattern: `^message_([0-9]?[0-9])?\.db$`, + BlackList: []string{}, + }, + { + Name: Contact, + Pattern: `^contact\.db$`, + BlackList: []string{}, + }, + { + Name: Session, + Pattern: `session\.db$`, + BlackList: []string{}, + }, + { + Name: Media, + Pattern: `^hardlink\.db$`, + BlackList: []string{}, + }, + { + Name: Voice, + Pattern: `^media_([0-9]?[0-9])?\.db$`, + BlackList: []string{}, + }, +} + // MessageDBInfo 存储消息数据库的信息 type MessageDBInfo struct { FilePath string @@ -34,61 +63,65 @@ type MessageDBInfo struct { } type DataSource struct { - path string - messageDbs map[string]*sql.DB - contactDb *sql.DB - sessionDb *sql.DB - mediaDb *sql.DB - voiceDb []*sql.DB + path string + dbm *dbm.DBManager // 消息数据库信息 - messageFiles []MessageDBInfo + messageInfos []MessageDBInfo } func New(path string) (*DataSource, error) { + ds := &DataSource{ path: path, - messageDbs: make(map[string]*sql.DB), - voiceDb: make([]*sql.DB, 0), - messageFiles: make([]MessageDBInfo, 0), + dbm: dbm.NewDBManager(path), + messageInfos: make([]MessageDBInfo, 0), } - if err := ds.initMessageDbs(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initContactDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initSessionDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initMediaDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - if err := ds.initVoiceDb(path); err != nil { + for _, g := range Groups { + ds.dbm.AddGroup(g) + } + + if err := ds.dbm.Start(); err != nil { + return nil, err + } + + if err := ds.initMessageDbs(); err != nil { return nil, errors.DBInitFailed(err) } + ds.dbm.AddCallback(Message, func(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := ds.initMessageDbs(); err != nil { + log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name) + } + return nil + }) + return ds, nil } -func (ds *DataSource) initMessageDbs(path string) error { - // 查找所有消息数据库文件 - files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, MessageFilePattern, err) +func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error { + if name == "chatroom" { + name = Contact } + return ds.dbm.AddCallback(name, callback) +} - if len(files) == 0 { - return errors.DBFileNotFound(path, MessageFilePattern, nil) +func (ds *DataSource) initMessageDbs() error { + dbPaths, err := ds.dbm.GetDBPath(Message) + if err != nil { + return err } // 处理每个数据库文件 - for _, filePath := range files { - // 连接数据库 - db, err := sql.Open("sqlite3", filePath) + infos := make([]MessageDBInfo, 0) + for _, filePath := range dbPaths { + db, err := ds.dbm.OpenDB(filePath) if err != nil { - log.Err(err).Msgf("连接数据库 %s 失败", filePath) + log.Err(err).Msgf("获取数据库 %s 失败", filePath) continue } @@ -99,108 +132,38 @@ func (ds *DataSource) initMessageDbs(path string) error { row := db.QueryRow("SELECT timestamp FROM Timestamp LIMIT 1") if err := row.Scan(×tamp); err != nil { log.Err(err).Msgf("获取数据库 %s 的时间戳失败", filePath) - db.Close() continue } startTime = time.Unix(timestamp, 0) // 保存数据库信息 - ds.messageFiles = append(ds.messageFiles, MessageDBInfo{ + infos = append(infos, MessageDBInfo{ FilePath: filePath, StartTime: startTime, }) - - // 保存数据库连接 - ds.messageDbs[filePath] = db } // 按照 StartTime 排序数据库文件 - sort.Slice(ds.messageFiles, func(i, j int) bool { - return ds.messageFiles[i].StartTime.Before(ds.messageFiles[j].StartTime) + sort.Slice(infos, func(i, j int) bool { + return infos[i].StartTime.Before(infos[j].StartTime) }) // 设置结束时间 - for i := range ds.messageFiles { - if i == len(ds.messageFiles)-1 { - ds.messageFiles[i].EndTime = time.Now() + for i := range infos { + if i == len(infos)-1 { + infos[i].EndTime = time.Now() } else { - ds.messageFiles[i].EndTime = ds.messageFiles[i+1].StartTime + infos[i].EndTime = infos[i+1].StartTime } } - - return nil -} - -func (ds *DataSource) initContactDb(path string) error { - files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, ContactFilePattern, err) - } - - if len(files) == 0 { - return errors.DBFileNotFound(path, ContactFilePattern, nil) - } - - ds.contactDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - return nil -} - -func (ds *DataSource) initSessionDb(path string) error { - files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, SessionFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, SessionFilePattern, nil) - } - ds.sessionDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - return nil -} - -func (ds *DataSource) initMediaDb(path string) error { - files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, MediaFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, MediaFilePattern, nil) - } - ds.mediaDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - return nil -} - -func (ds *DataSource) initVoiceDb(path string) error { - files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, VoiceFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, VoiceFilePattern, nil) - } - for _, file := range files { - db, err := sql.Open("sqlite3", file) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - ds.voiceDb = append(ds.voiceDb, db) - } + ds.messageInfos = infos return nil } // getDBInfosForTimeRange 获取时间范围内的数据库信息 func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo { var dbs []MessageDBInfo - for _, info := range ds.messageFiles { + for _, info := range ds.messageInfos { if info.StartTime.Before(endTime) && info.EndTime.After(startTime) { dbs = append(dbs, info) } @@ -234,8 +197,8 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T return nil, err } - db, ok := ds.messageDbs[dbInfo.FilePath] - if !ok { + db, err := ds.dbm.OpenDB(dbInfo.FilePath) + if err != nil { log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath) continue } @@ -275,8 +238,8 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // getMessagesSingleFile 从单个数据库文件获取消息 func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { - db, ok := ds.messageDbs[dbInfo.FilePath] - if !ok { + db, err := ds.dbm.OpenDB(dbInfo.FilePath) + if err != nil { return nil, errors.DBConnectFailed(dbInfo.FilePath, nil) } @@ -287,7 +250,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD // 检查表是否存在 var exists bool - err := db.QueryRowContext(ctx, + err = db.QueryRowContext(ctx, "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&exists) @@ -445,7 +408,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -477,13 +444,18 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse var query string var args []interface{} + // 执行查询 + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + if key != "" { // 按照关键字查询 query = `SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?` args = []interface{}{key} - // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -510,7 +482,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts, err := ds.GetContacts(ctx, key, 1, 0) if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") { // 再次尝试通过用户名查找群聊 - rows, err := ds.contactDb.QueryContext(ctx, + rows, err := db.QueryContext(ctx, `SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?`, contacts[0].UserName) @@ -560,7 +532,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -614,7 +586,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.sessionDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Session) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -678,7 +654,12 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* query += " WHERE f.md5 = ? OR f.file_name LIKE ? || '%'" args := []interface{}{key, key} - rows, err := ds.mediaDb.QueryContext(ctx, query, args...) + // 执行查询 + db, err := ds.dbm.GetDB(Media) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -726,7 +707,12 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e ` args := []interface{}{key} - for _, db := range ds.voiceDb { + dbs, err := ds.dbm.GetDBs(Voice) + if err != nil { + return nil, errors.DBConnectFailed("", err) + } + + for _, db := range dbs { rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) @@ -755,38 +741,5 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e } func (ds *DataSource) Close() error { - var errs []error - - // 关闭消息数据库连接 - for _, db := range ds.messageDbs { - if err := db.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭联系人数据库连接 - if ds.contactDb != nil { - if err := ds.contactDb.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭会话数据库连接 - if ds.sessionDb != nil { - if err := ds.sessionDb.Close(); err != nil { - errs = append(errs, err) - } - } - - if ds.mediaDb != nil { - if err := ds.mediaDb.Close(); err != nil { - errs = append(errs, err) - } - } - - if len(errs) > 0 { - return errors.DBCloseFailed(errs[0]) - } - - return nil + return ds.dbm.Close() } diff --git a/internal/wechatdb/datasource/windowsv3/datasource.go b/internal/wechatdb/datasource/windowsv3/datasource.go index 41367ab..2a8434c 100644 --- a/internal/wechatdb/datasource/windowsv3/datasource.go +++ b/internal/wechatdb/datasource/windowsv3/datasource.go @@ -9,23 +9,57 @@ import ( "strings" "time" + "github.com/fsnotify/fsnotify" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" - "github.com/sjzar/chatlog/pkg/util" + "github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm" ) const ( - MessageFilePattern = "^MSG([0-9]?[0-9])?\\.db$" - ContactFilePattern = "^MicroMsg.db$" - ImageFilePattern = "^HardLinkImage\\.db$" - VideoFilePattern = "^HardLinkVideo\\.db$" - FileFilePattern = "^HardLinkFile\\.db$" - VoiceFilePattern = "^MediaMSG([0-9])?\\.db$" + Message = "message" + Contact = "contact" + Image = "image" + Video = "video" + File = "file" + Voice = "voice" ) +var Groups = []dbm.Group{ + { + Name: Message, + Pattern: `^MSG([0-9]?[0-9])?\.db$`, + BlackList: []string{}, + }, + { + Name: Contact, + Pattern: `^MicroMsg.db$`, + BlackList: []string{}, + }, + { + Name: Image, + Pattern: `^HardLinkImage\.db$`, + BlackList: []string{}, + }, + { + Name: Video, + Pattern: `^HardLinkVideo\.db$`, + BlackList: []string{}, + }, + { + Name: File, + Pattern: `^HardLinkFile\.db$`, + BlackList: []string{}, + }, + { + Name: Voice, + Pattern: `^MediaMSG([0-9])?\.db$`, + BlackList: []string{}, + }, +} + // MessageDBInfo 保存消息数据库的信息 type MessageDBInfo struct { FilePath string @@ -36,67 +70,67 @@ type MessageDBInfo struct { // DataSource 实现了 DataSource 接口 type DataSource struct { - // 消息数据库 - messageFiles []MessageDBInfo - messageDbs map[string]*sql.DB + path string + dbm *dbm.DBManager - // 联系人数据库 - contactDbFile string - contactDb *sql.DB - - imageDb *sql.DB - videoDb *sql.DB - fileDb *sql.DB - voiceDb []*sql.DB + // 消息数据库信息 + messageInfos []MessageDBInfo } // New 创建一个新的 WindowsV3DataSource func New(path string) (*DataSource, error) { ds := &DataSource{ - messageFiles: make([]MessageDBInfo, 0), - messageDbs: make(map[string]*sql.DB), - voiceDb: make([]*sql.DB, 0), + path: path, + dbm: dbm.NewDBManager(path), + messageInfos: make([]MessageDBInfo, 0), } - // 初始化消息数据库 - if err := ds.initMessageDbs(path); err != nil { + for _, g := range Groups { + ds.dbm.AddGroup(g) + } + + if err := ds.dbm.Start(); err != nil { + return nil, err + } + + if err := ds.initMessageDbs(); err != nil { return nil, errors.DBInitFailed(err) } - // 初始化联系人数据库 - if err := ds.initContactDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - - if err := ds.initMediaDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } - - if err := ds.initVoiceDb(path); err != nil { - return nil, errors.DBInitFailed(err) - } + ds.dbm.AddCallback(Message, func(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := ds.initMessageDbs(); err != nil { + log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name) + } + return nil + }) return ds, nil } -// initMessageDbs 初始化消息数据库 -func (ds *DataSource) initMessageDbs(path string) error { - // 查找所有消息数据库文件 - files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, MessageFilePattern, err) +func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error { + if name == "chatroom" { + name = Contact } + return ds.dbm.AddCallback(name, callback) +} - if len(files) == 0 { - return errors.DBFileNotFound(path, MessageFilePattern, nil) +// initMessageDbs 初始化消息数据库 +func (ds *DataSource) initMessageDbs() error { + // 获取所有消息数据库文件路径 + dbPaths, err := ds.dbm.GetDBPath(Message) + if err != nil { + return err } // 处理每个数据库文件 - for _, filePath := range files { - // 连接数据库 - db, err := sql.Open("sqlite3", filePath) + infos := make([]MessageDBInfo, 0) + for _, filePath := range dbPaths { + db, err := ds.dbm.OpenDB(filePath) if err != nil { - log.Err(err).Msgf("连接数据库 %s 失败", filePath) + log.Err(err).Msgf("获取数据库 %s 失败", filePath) continue } @@ -106,7 +140,6 @@ func (ds *DataSource) initMessageDbs(path string) error { rows, err := db.Query("SELECT tableIndex, tableVersion, tableDesc FROM DBInfo") if err != nil { log.Err(err).Msgf("查询数据库 %s 的 DBInfo 表失败", filePath) - db.Close() continue } @@ -133,7 +166,6 @@ func (ds *DataSource) initMessageDbs(path string) error { rows, err = db.Query("SELECT UsrName FROM Name2ID") if err != nil { log.Err(err).Msgf("查询数据库 %s 的 Name2ID 表失败", filePath) - db.Close() continue } @@ -150,123 +182,34 @@ func (ds *DataSource) initMessageDbs(path string) error { rows.Close() // 保存数据库信息 - ds.messageFiles = append(ds.messageFiles, MessageDBInfo{ + infos = append(infos, MessageDBInfo{ FilePath: filePath, StartTime: startTime, TalkerMap: talkerMap, }) - - // 保存数据库连接 - ds.messageDbs[filePath] = db } // 按照 StartTime 排序数据库文件 - sort.Slice(ds.messageFiles, func(i, j int) bool { - return ds.messageFiles[i].StartTime.Before(ds.messageFiles[j].StartTime) + sort.Slice(infos, func(i, j int) bool { + return infos[i].StartTime.Before(infos[j].StartTime) }) // 设置结束时间 - for i := range ds.messageFiles { - if i == len(ds.messageFiles)-1 { - ds.messageFiles[i].EndTime = time.Now() + for i := range infos { + if i == len(infos)-1 { + infos[i].EndTime = time.Now() } else { - ds.messageFiles[i].EndTime = ds.messageFiles[i+1].StartTime + infos[i].EndTime = infos[i+1].StartTime } } - - return nil -} - -// initContactDb 初始化联系人数据库 -func (ds *DataSource) initContactDb(path string) error { - files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, ContactFilePattern, err) - } - - if len(files) == 0 { - return errors.DBFileNotFound(path, ContactFilePattern, nil) - } - - ds.contactDbFile = files[0] - - ds.contactDb, err = sql.Open("sqlite3", ds.contactDbFile) - if err != nil { - return errors.DBConnectFailed(ds.contactDbFile, err) - } - - return nil -} - -// initContactDb 初始化联系人数据库 -func (ds *DataSource) initMediaDb(path string) error { - files, err := util.FindFilesWithPatterns(path, ImageFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, ImageFilePattern, err) - } - - if len(files) == 0 { - return errors.DBFileNotFound(path, ImageFilePattern, nil) - } - - ds.imageDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - files, err = util.FindFilesWithPatterns(path, VideoFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, VideoFilePattern, err) - } - - if len(files) == 0 { - return errors.DBFileNotFound(path, VideoFilePattern, nil) - } - - ds.videoDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - files, err = util.FindFilesWithPatterns(path, FileFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, FileFilePattern, err) - } - - if len(files) == 0 { - return errors.DBFileNotFound(path, FileFilePattern, nil) - } - - ds.fileDb, err = sql.Open("sqlite3", files[0]) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - - return nil -} - -func (ds *DataSource) initVoiceDb(path string) error { - files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true) - if err != nil { - return errors.DBFileNotFound(path, VoiceFilePattern, err) - } - if len(files) == 0 { - return errors.DBFileNotFound(path, VoiceFilePattern, nil) - } - for _, file := range files { - db, err := sql.Open("sqlite3", file) - if err != nil { - return errors.DBConnectFailed(files[0], err) - } - ds.voiceDb = append(ds.voiceDb, db) - } + ds.messageInfos = infos return nil } // getDBInfosForTimeRange 获取时间范围内的数据库信息 func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo { var dbs []MessageDBInfo - for _, info := range ds.messageFiles { + for _, info := range ds.messageInfos { if info.StartTime.Before(endTime) && info.EndTime.After(startTime) { dbs = append(dbs, info) } @@ -296,70 +239,19 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T return nil, err } - db, ok := ds.messageDbs[dbInfo.FilePath] - if !ok { + db, err := ds.dbm.OpenDB(dbInfo.FilePath) + if err != nil { log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath) continue } - // 构建查询条件 - conditions := []string{"Sequence >= ? AND Sequence <= ?"} - args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} - - if len(talker) > 0 { - talkerID, ok := dbInfo.TalkerMap[talker] - if ok { - conditions = append(conditions, "TalkerId = ?") - args = append(args, talkerID) - } else { - conditions = append(conditions, "StrTalker = ?") - args = append(args, talker) - } - } - - query := fmt.Sprintf(` - SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender, - Type, SubType, StrContent, CompressContent, BytesExtra - FROM MSG - WHERE %s - ORDER BY Sequence ASC - `, strings.Join(conditions, " AND ")) - - // 执行查询 - rows, err := db.QueryContext(ctx, query, args...) + messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker) if err != nil { - log.Err(err).Msgf("查询数据库 %s 失败", dbInfo.FilePath) + log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) continue } - // 处理查询结果 - for rows.Next() { - var msg model.MessageV3 - var compressContent []byte - var bytesExtra []byte - - err := rows.Scan( - &msg.MsgSvrID, - &msg.Sequence, - &msg.CreateTime, - &msg.StrTalker, - &msg.IsSender, - &msg.Type, - &msg.SubType, - &msg.StrContent, - &compressContent, - &bytesExtra, - ) - if err != nil { - log.Err(err).Msg("扫描消息行失败") - continue - } - msg.CompressContent = compressContent - msg.BytesExtra = bytesExtra - - totalMessages = append(totalMessages, msg.Wrap()) - } - rows.Close() + totalMessages = append(totalMessages, messages...) if limit+offset > 0 && len(totalMessages) >= limit+offset { break @@ -388,6 +280,11 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // getMessagesSingleFile 从单个数据库文件获取消息 func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { + db, err := ds.dbm.OpenDB(dbInfo.FilePath) + if err != nil { + return nil, errors.DBConnectFailed(dbInfo.FilePath, nil) + } + // 构建查询条件 conditions := []string{"Sequence >= ? AND Sequence <= ?"} args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} @@ -419,7 +316,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD } // 执行查询 - rows, err := ds.messageDbs[dbInfo.FilePath].QueryContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -453,6 +350,69 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD return totalMessages, nil } +// getMessagesFromDB 从数据库获取消息 +func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) { + // 构建查询条件 + conditions := []string{"Sequence >= ? AND Sequence <= ?"} + args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000} + + if len(talker) > 0 { + talkerID, ok := dbInfo.TalkerMap[talker] + if ok { + conditions = append(conditions, "TalkerId = ?") + args = append(args, talkerID) + } else { + conditions = append(conditions, "StrTalker = ?") + args = append(args, talker) + } + } + + query := fmt.Sprintf(` + SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender, + Type, SubType, StrContent, CompressContent, BytesExtra + FROM MSG + WHERE %s + ORDER BY Sequence ASC + `, strings.Join(conditions, " AND ")) + + // 执行查询 + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.QueryFailed(query, err) + } + defer rows.Close() + + // 处理查询结果 + messages := []*model.Message{} + for rows.Next() { + var msg model.MessageV3 + var compressContent []byte + var bytesExtra []byte + + err := rows.Scan( + &msg.MsgSvrID, + &msg.Sequence, + &msg.CreateTime, + &msg.StrTalker, + &msg.IsSender, + &msg.Type, + &msg.SubType, + &msg.StrContent, + &compressContent, + &bytesExtra, + ) + if err != nil { + return nil, errors.ScanRowFailed(err) + } + msg.CompressContent = compressContent + msg.BytesExtra = bytesExtra + + messages = append(messages, msg.Wrap()) + } + + return messages, nil +} + // GetContacts 实现获取联系人信息的方法 func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error) { var query string @@ -478,7 +438,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -516,7 +480,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse args = []interface{}{key} // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -543,7 +511,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts, err := ds.GetContacts(ctx, key, 1, 0) if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") { // 再次尝试通过用户名查找群聊 - rows, err := ds.contactDb.QueryContext(ctx, + rows, err := db.QueryContext(ctx, `SELECT ChatRoomName, Reserved2, RoomData FROM ChatRoom WHERE ChatRoomName = ?`, contacts[0].UserName) @@ -593,7 +561,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -647,7 +619,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset } // 执行查询 - rows, err := ds.contactDb.QueryContext(ctx, query, args...) + db, err := ds.dbm.GetDB(Contact) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) } @@ -688,25 +664,29 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* return nil, errors.DecodeKeyFailed(err) } - var db *sql.DB + var dbType string var table1, table2 string switch _type { case "image": - db = ds.imageDb + dbType = Image table1 = "HardLinkImageAttribute" table2 = "HardLinkImageID" case "video": - db = ds.videoDb + dbType = Video table1 = "HardLinkVideoAttribute" table2 = "HardLinkVideoID" case "file": - db = ds.fileDb + dbType = File table1 = "HardLinkFileAttribute" table2 = "HardLinkFileID" default: return nil, errors.MediaTypeUnsupported(_type) + } + db, err := ds.dbm.GetDB(dbType) + if err != nil { + return nil, err } query := fmt.Sprintf(` @@ -768,7 +748,12 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e ` args := []interface{}{key} - for _, db := range ds.voiceDb { + dbs, err := ds.dbm.GetDBs(Voice) + if err != nil { + return nil, errors.DBConnectFailed("", err) + } + + for _, db := range dbs { rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, errors.QueryFailed(query, err) @@ -798,41 +783,5 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e // Close 实现 DataSource 接口的 Close 方法 func (ds *DataSource) Close() error { - var errs []error - - // 关闭消息数据库连接 - for _, db := range ds.messageDbs { - if err := db.Close(); err != nil { - errs = append(errs, err) - } - } - - // 关闭联系人数据库连接 - if ds.contactDb != nil { - if err := ds.contactDb.Close(); err != nil { - errs = append(errs, err) - } - } - - if ds.imageDb != nil { - if err := ds.imageDb.Close(); err != nil { - errs = append(errs, err) - } - } - if ds.videoDb != nil { - if err := ds.videoDb.Close(); err != nil { - errs = append(errs, err) - } - } - if ds.fileDb != nil { - if err := ds.fileDb.Close(); err != nil { - errs = append(errs, err) - } - } - - if len(errs) > 0 { - return errors.DBCloseFailed(errs[0]) - } - - return nil + return ds.dbm.Close() } diff --git a/internal/wechatdb/repository/repository.go b/internal/wechatdb/repository/repository.go index 6fa8ce3..5cdcbca 100644 --- a/internal/wechatdb/repository/repository.go +++ b/internal/wechatdb/repository/repository.go @@ -3,6 +3,9 @@ package repository import ( "context" + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource" @@ -61,6 +64,9 @@ func New(ds datasource.DataSource) (*Repository, error) { return nil, errors.InitCacheFailed(err) } + ds.SetCallback("contact", r.contactCallback) + ds.SetCallback("chatroom", r.chatroomCallback) + return r, nil } @@ -79,6 +85,26 @@ func (r *Repository) initCache(ctx context.Context) error { return nil } +func (r *Repository) contactCallback(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := r.initContactCache(context.Background()); err != nil { + log.Err(err).Msgf("Failed to reinitialize contact cache: %s", event.Name) + } + return nil +} + +func (r *Repository) chatroomCallback(event fsnotify.Event) error { + if !event.Op.Has(fsnotify.Create) { + return nil + } + if err := r.initChatRoomCache(context.Background()); err != nil { + log.Err(err).Msgf("Failed to reinitialize contact cache: %s", event.Name) + } + return nil +} + // Close 实现 Repository 接口的 Close 方法 func (r *Repository) Close() error { return r.ds.Close() diff --git a/pkg/filecopy/filecopy.go b/pkg/filecopy/filecopy.go new file mode 100644 index 0000000..8c11a74 --- /dev/null +++ b/pkg/filecopy/filecopy.go @@ -0,0 +1,628 @@ +package filecopy + +import ( + "encoding/json" + "fmt" + "hash/fnv" + "io" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + // Singleton locks to ensure only one thread processes the same file at a time + fileOperationLocks = make(map[string]*sync.Mutex) + locksMutex = sync.RWMutex{} + + // Mapping from original file paths to temporary file paths + pathToTempFile = make(map[string]string) + // Metadata information for original files + fileMetadata = make(map[string]fileMetaInfo) + // Track old versions of temporary files for each original file + oldVersions = make(map[string]string) + mapMutex = sync.RWMutex{} + + // Temporary directory + tempDir string + // Path to the mapping file + mappingFilePath string + + // Channel for delayed file deletion + fileDeletionChan = make(chan FileDeletion, 1000) + + // Default deletion delay time (30 seconds) + DefaultDeletionDelay = 30 * time.Second +) + +type FileDeletion struct { + Path string + Time time.Time +} + +// File metadata information +type fileMetaInfo struct { + ModTime time.Time `json:"mod_time"` + Size int64 `json:"size"` +} + +// Persistent mapping information +type persistentMapping struct { + OriginalPath string `json:"original_path"` + TempPath string `json:"temp_path"` + Metadata fileMetaInfo `json:"metadata"` +} + +// Initialize temporary directory +func initTempDir() { + // Get process name to create a unique temporary directory + procName := getProcessName() + tempDir = filepath.Join(os.TempDir(), "filecopy_"+procName) + + if err := os.MkdirAll(tempDir, 0755); err != nil { + tempDir = filepath.Join(os.TempDir(), "filecopy") + if err := os.MkdirAll(tempDir, 0755); err != nil { + tempDir = os.TempDir() + } + } + + // Set mapping file path + mappingFilePath = filepath.Join(tempDir, "file_mappings.json") + + // Load existing mappings if available + loadMappings() + + // Scan and clean existing temporary files + cleanupExistingTempFiles() +} + +// Get process name +func getProcessName() string { + executable, err := os.Executable() + if err != nil { + return "unknown" + } + + // Extract base name (without extension) + baseName := filepath.Base(executable) + ext := filepath.Ext(baseName) + if ext != "" { + baseName = baseName[:len(baseName)-len(ext)] + } + + // Clean name, keep only letters, numbers, underscores and hyphens + baseName = strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_' { + return r + } + return '_' + }, baseName) + + return baseName +} + +// Load file mappings from persistent storage +func loadMappings() { + file, err := os.Open(mappingFilePath) + if err != nil { + // It's okay if the file doesn't exist yet + return + } + defer file.Close() + + var mappings []persistentMapping + decoder := json.NewDecoder(file) + if err := decoder.Decode(&mappings); err != nil { + // If the file is corrupted, we'll just start fresh + return + } + + // Restore mappings + mapMutex.Lock() + defer mapMutex.Unlock() + + for _, mapping := range mappings { + // Verify that both the original file and temp file still exist + origStat, origErr := os.Stat(mapping.OriginalPath) + _, tempErr := os.Stat(mapping.TempPath) + + if origErr == nil && tempErr == nil { + // Check if the original file has changed since the mapping was saved + if origStat.ModTime() == mapping.Metadata.ModTime && origStat.Size() == mapping.Metadata.Size { + // The mapping is still valid + pathToTempFile[mapping.OriginalPath] = mapping.TempPath + fileMetadata[mapping.OriginalPath] = mapping.Metadata + } + } + } +} + +// Save file mappings to persistent storage +func saveMappings() { + mapMutex.RLock() + defer mapMutex.RUnlock() + + var mappings []persistentMapping + for origPath, tempPath := range pathToTempFile { + if meta, exists := fileMetadata[origPath]; exists { + mappings = append(mappings, persistentMapping{ + OriginalPath: origPath, + TempPath: tempPath, + Metadata: meta, + }) + } + } + + // Create the file + file, err := os.Create(mappingFilePath) + if err != nil { + return + } + defer file.Close() + + // Write the mappings + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err := encoder.Encode(mappings); err != nil { + // If we can't save, just continue - it's not critical + return + } +} + +// Clean up existing temporary files +func cleanupExistingTempFiles() { + files, err := os.ReadDir(tempDir) + if err != nil { + return + } + + // Skip the mapping file + mappingFileName := filepath.Base(mappingFilePath) + + // First, collect all files that are already in our mapping + knownFiles := make(map[string]bool) + mapMutex.RLock() + for _, tempPath := range pathToTempFile { + knownFiles[tempPath] = true + } + mapMutex.RUnlock() + + // Group files by prefix (baseName_hashPrefix) + fileGroups := make(map[string][]tempFileInfo) + + for _, file := range files { + if file.IsDir() { + continue + } + + fileName := file.Name() + + // Skip the mapping file + if fileName == mappingFileName { + continue + } + + filePath := filepath.Join(tempDir, fileName) + parts := strings.Split(fileName, "_") + + // Skip files that don't match our naming convention + if len(parts) < 3 { + removeFileImmediately(filePath) + continue + } + + // Extract base name and hash part as key + baseName := parts[0] + hashPart := parts[1] + groupKey := baseName + "_" + hashPart + + // Extract timestamp + timeStr := strings.Split(parts[2], ".")[0] // Remove extension part + var timestamp int64 + if _, err := fmt.Sscanf(timeStr, "%d", ×tamp); err != nil { + removeFileImmediately(filePath) + continue + } + + // Add file info to corresponding group + fileGroups[groupKey] = append(fileGroups[groupKey], tempFileInfo{ + path: filePath, + timestamp: timestamp, + }) + } + + // Process each group of files, keep only the newest one + for _, fileInfos := range fileGroups { + if len(fileInfos) == 0 { + continue + } + + // Find the newest file + var newestFile tempFileInfo + for _, fileInfo := range fileInfos { + if fileInfo.timestamp > newestFile.timestamp { + newestFile = fileInfo + } + } + + // Delete all files except the newest one + for _, fileInfo := range fileInfos { + if fileInfo.path != newestFile.path { + // If this file is already in our mapping, keep it + if knownFiles[fileInfo.path] { + continue + } + removeFileImmediately(fileInfo.path) + } + } + } +} + +// Temporary file information +type tempFileInfo struct { + path string + timestamp int64 +} + +// Get file lock +func getFileLock(path string) *sync.Mutex { + locksMutex.RLock() + lock, exists := fileOperationLocks[path] + locksMutex.RUnlock() + + if exists { + return lock + } + + locksMutex.Lock() + defer locksMutex.Unlock() + + // Check again, might have been created while we were acquiring the write lock + lock, exists = fileOperationLocks[path] + if !exists { + lock = &sync.Mutex{} + fileOperationLocks[path] = lock + } + + return lock +} + +// GetTempCopy returns a temporary copy path of the original file +// If the file hasn't changed since the last copy, returns the existing copy +func GetTempCopy(originalPath string) (string, error) { + // Get the operation lock for this file to ensure thread safety + fileLock := getFileLock(originalPath) + fileLock.Lock() + defer fileLock.Unlock() + + // Check if original file exists + stat, err := os.Stat(originalPath) + if err != nil { + return "", fmt.Errorf("original file does not exist: %w", err) + } + + // Current file info + currentInfo := fileMetaInfo{ + ModTime: stat.ModTime(), + Size: stat.Size(), + } + + // Check existing mapping + mapMutex.RLock() + tempPath, pathExists := pathToTempFile[originalPath] + cachedInfo, infoExists := fileMetadata[originalPath] + mapMutex.RUnlock() + + // If we have an existing temp file and original file hasn't changed, return it + if pathExists && infoExists { + fileChanged := currentInfo.ModTime.After(cachedInfo.ModTime) || + currentInfo.Size != cachedInfo.Size + + if !fileChanged { + // Verify temp file still exists + if _, err := os.Stat(tempPath); err == nil { + // Try to open file to verify accessibility + if file, err := os.Open(tempPath); err == nil { + file.Close() + return tempPath, nil + } + } + } + } + + // Generate new temp file path + fileName := filepath.Base(originalPath) + fileExt := filepath.Ext(fileName) + baseName := fileName[:len(fileName)-len(fileExt)] + if baseName == "" { + baseName = "file" // Use default name if empty + } + + // Generate hash for original path + pathHash := hashString(originalPath) + hashPrefix := getHashPrefix(pathHash, 8) + + // Format: basename_pathhash_timestamp.ext + timestamp := time.Now().UnixNano() + tempPath = filepath.Join(tempDir, + fmt.Sprintf("%s_%s_%d%s", + baseName, + hashPrefix, + timestamp, + fileExt)) + + // Copy file (with retry mechanism) + if err := copyFileWithRetry(originalPath, tempPath, 3); err != nil { + return "", err + } + + // Update mappings + mapMutex.Lock() + oldPath := pathToTempFile[originalPath] + + // If there's an old path and it's different, move it to old versions and schedule for deletion + if oldPath != "" && oldPath != tempPath { + // First clean up previous old version (if any) + if oldVersionPath, hasOldVersion := oldVersions[originalPath]; hasOldVersion && oldVersionPath != oldPath { + removeFileImmediately(oldVersionPath) + } + + // Set current version as old version + oldVersions[originalPath] = oldPath + scheduleForDeletion(oldPath) + } + + // Update to new temp file + pathToTempFile[originalPath] = tempPath + fileMetadata[originalPath] = currentInfo + mapMutex.Unlock() + + // Save mappings to persistent storage + go saveMappings() + + // Immediately clean up any other related temp files + go cleanupRelatedTempFiles(originalPath, tempPath, oldPath) + + return tempPath, nil +} + +// Immediately clean up other temp files related to the specified original file +func cleanupRelatedTempFiles(originalPath, currentTempPath, knownOldPath string) { + // Extract hash prefix of original file to match related files + fileName := filepath.Base(originalPath) + fileExt := filepath.Ext(fileName) + baseName := fileName[:len(fileName)-len(fileExt)] + if baseName == "" { + baseName = "file" + } + + pathHash := hashString(originalPath) + hashPrefix := getHashPrefix(pathHash, 8) + + // File name prefix pattern + filePrefix := baseName + "_" + hashPrefix + + currentTempPathNoExt := strings.TrimSuffix(currentTempPath, filepath.Ext(currentTempPath)) + knownOldPathNoExt := strings.TrimSuffix(knownOldPath, filepath.Ext(knownOldPath)) + + files, err := os.ReadDir(tempDir) + if err != nil { + return + } + + for _, file := range files { + if file.IsDir() { + continue + } + + fileName := file.Name() + + // Skip the mapping file + if fileName == filepath.Base(mappingFilePath) { + continue + } + + filePath := filepath.Join(tempDir, fileName) + filePathNoExt := strings.TrimSuffix(filePath, filepath.Ext(filePath)) + + // Skip current file and known old version + if filePathNoExt == currentTempPathNoExt || filePathNoExt == knownOldPathNoExt { + continue + } + + // If file name matches our pattern, delete it immediately + if strings.HasPrefix(fileName, filePrefix) { + removeFileImmediately(filePath) + } + } +} + +// Immediately delete file without waiting for delay +func removeFileImmediately(path string) { + if path == "" { + return + } + + // Try to delete file + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + // Silently fail if we can't delete + } +} + +// Schedule file for delayed deletion +func scheduleForDeletion(path string) { + if path == "" { + return + } + + // Check if file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + return + } + + // Put file in deletion channel + select { + case fileDeletionChan <- FileDeletion{Path: path, Time: time.Now().Add(DefaultDeletionDelay)}: + // Successfully scheduled + default: + // If channel is full, delete file immediately + removeFileImmediately(path) + } +} + +// File deletion handler +func fileDeletionHandler() { + for { + // Get file to delete from channel + file := <-fileDeletionChan + + if !time.Now().After(file.Time) { + time.Sleep(time.Until(file.Time)) + } + + // Ensure file is not in active mappings + isActive := false + mapMutex.RLock() + for _, activePath := range pathToTempFile { + if activePath == file.Path { + isActive = true + break + } + } + + mapMutex.RUnlock() + + if isActive { + continue + } + + // Delete file + removeFileImmediately(file.Path) + } +} + +// CleanupTempFiles cleans up unused temporary files +func CleanupTempFiles() { + files, err := os.ReadDir(tempDir) + if err != nil { + return + } + + // Skip the mapping file + mappingFileName := filepath.Base(mappingFilePath) + + // Get current active temp file paths and old version paths + mapMutex.RLock() + activeTempFiles := make(map[string]bool) + for _, tempFilePath := range pathToTempFile { + tempFilePath = strings.TrimSuffix(tempFilePath, filepath.Ext(tempFilePath)) + activeTempFiles[tempFilePath] = true + } + for _, oldVersionPath := range oldVersions { + oldVersionPath = strings.TrimSuffix(oldVersionPath, filepath.Ext(oldVersionPath)) + activeTempFiles[oldVersionPath] = true + } + mapMutex.RUnlock() + + // Schedule deletion of inactive temp files + for _, file := range files { + if file.IsDir() { + continue + } + + fileName := file.Name() + + // Skip the mapping file + if fileName == mappingFileName { + continue + } + + tempFilePath := filepath.Join(tempDir, fileName) + tempFilePath = strings.TrimSuffix(tempFilePath, filepath.Ext(tempFilePath)) + if !activeTempFiles[tempFilePath] { + scheduleForDeletion(tempFilePath) + } + } +} + +// Copy file with retry mechanism +func copyFileWithRetry(src, dst string, maxRetries int) error { + var err error + for i := 0; i < maxRetries; i++ { + err = copyFile(src, dst) + if err == nil { + return nil + } + + // Wait before retrying + time.Sleep(time.Duration(100*(i+1)) * time.Millisecond) + } + return fmt.Errorf("failed to copy file after %d attempts: %w", maxRetries, err) +} + +// Copy file +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer func() { + cerr := out.Close() + if err == nil && cerr != nil { + err = fmt.Errorf("failed to close destination file: %w", cerr) + } + }() + + // Use buffered copy for better performance + buf := make([]byte, 256*1024) // 256KB buffer + if _, err = io.CopyBuffer(out, in, buf); err != nil { + return fmt.Errorf("failed to copy file contents: %w", err) + } + + return out.Sync() +} + +// Generate hash for string +func hashString(s string) string { + h := fnv.New32a() + h.Write([]byte(s)) + return fmt.Sprintf("%x", h.Sum32()) +} + +// Safely get hash prefix, avoid index out of bounds +func getHashPrefix(hash string, length int) string { + if len(hash) <= length { + return hash + } + return hash[:length] +} + +// Initialize temp directory and start background cleanup +func init() { + // Initialize temp directory and scan existing files + initTempDir() + + // Start multiple file deletion handlers + for i := 0; i < 2; i++ { + go fileDeletionHandler() + } + + // Start periodic cleanup routine + go func() { + for { + time.Sleep(30 * time.Second) + CleanupTempFiles() + + // Also periodically save mappings + saveMappings() + } + }() +} diff --git a/pkg/filemonitor/filegroup.go b/pkg/filemonitor/filegroup.go new file mode 100644 index 0000000..11e64de --- /dev/null +++ b/pkg/filemonitor/filegroup.go @@ -0,0 +1,182 @@ +package filemonitor + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" +) + +// FileChangeCallback defines the callback function signature for file change events +type FileChangeCallback func(event fsnotify.Event) error + +// FileGroup represents a group of files with the same processing logic +type FileGroup struct { + ID string // Unique identifier + RootDir string // Root directory + Pattern *regexp.Regexp // File matching pattern + PatternStr string // Original pattern string for rebuilding + Blacklist []string // Blacklist patterns + Callbacks []FileChangeCallback // File change callbacks + mutex sync.RWMutex // Concurrency control +} + +// NewFileGroup creates a new file group +func NewFileGroup(id, rootDir, pattern string, blacklist []string) (*FileGroup, error) { + // Compile the regular expression + re, err := regexp.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("invalid pattern '%s': %w", pattern, err) + } + + // Normalize root directory path + rootDir = filepath.Clean(rootDir) + + return &FileGroup{ + ID: id, + RootDir: rootDir, + Pattern: re, + PatternStr: pattern, + Blacklist: blacklist, + Callbacks: []FileChangeCallback{}, + }, nil +} + +// AddCallback adds a callback function to the file group +func (fg *FileGroup) AddCallback(callback FileChangeCallback) { + fg.mutex.Lock() + defer fg.mutex.Unlock() + + fg.Callbacks = append(fg.Callbacks, callback) +} + +// RemoveCallback removes a callback function from the file group +func (fg *FileGroup) RemoveCallback(callbackToRemove FileChangeCallback) bool { + fg.mutex.Lock() + defer fg.mutex.Unlock() + + for i, callback := range fg.Callbacks { + // Compare function addresses + if fmt.Sprintf("%p", callback) == fmt.Sprintf("%p", callbackToRemove) { + // Remove the callback + fg.Callbacks = append(fg.Callbacks[:i], fg.Callbacks[i+1:]...) + return true + } + } + + return false +} + +// Match checks if a file path matches this group's criteria +func (fg *FileGroup) Match(path string) bool { + // Normalize paths for comparison + path = filepath.Clean(path) + rootDir := filepath.Clean(fg.RootDir) + + // Check if path is under root directory + // Use filepath.Rel to handle path comparison safely across different OSes + relPath, err := filepath.Rel(rootDir, path) + if err != nil || strings.HasPrefix(relPath, "..") { + return false + } + + // Check if file matches pattern + if !fg.Pattern.MatchString(filepath.Base(path)) { + return false + } + + // Check blacklist + for _, blackItem := range fg.Blacklist { + if strings.Contains(relPath, blackItem) { + return false + } + } + + return true +} + +// List returns a list of files in the group (real-time scan) +func (fg *FileGroup) List() ([]string, error) { + files := []string{} + + // Scan directory for matching files using fs.WalkDir + err := fs.WalkDir(os.DirFS(fg.RootDir), ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return fs.SkipDir + } + + // Skip directories + if d.IsDir() { + return nil + } + + // Convert relative path to absolute + absPath := filepath.Join(fg.RootDir, path) + + // Use Match function to check if file belongs to this group + if fg.Match(absPath) { + files = append(files, absPath) + } + + return nil + }) + + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("error listing files: %w", err) + } + + return files, nil +} + +// ListMatchingDirectories returns directories containing matching files +func (fg *FileGroup) ListMatchingDirectories() (map[string]bool, error) { + directories := make(map[string]bool) + + // Get matching files + files, err := fg.List() + if err != nil { + return nil, err + } + + // Extract directories from matching files + for _, file := range files { + dir := filepath.Dir(file) + directories[dir] = true + } + + return directories, nil +} + +// HandleEvent processes a file event and triggers callbacks if the file matches +func (fg *FileGroup) HandleEvent(event fsnotify.Event) { + // Check if this event is relevant for this group + if !fg.Match(event.Name) { + return + } + + // Get callbacks under read lock + fg.mutex.RLock() + callbacks := make([]FileChangeCallback, len(fg.Callbacks)) + copy(callbacks, fg.Callbacks) + fg.mutex.RUnlock() + + // Asynchronously call callbacks + for _, callback := range callbacks { + go func(cb FileChangeCallback) { + if err := cb(event); err != nil { + log.Error(). + Str("file", event.Name). + Str("op", event.Op.String()). + Err(err). + Msg("Callback error") + } + }(callback) + } +} diff --git a/pkg/filemonitor/filemonitor.go b/pkg/filemonitor/filemonitor.go new file mode 100644 index 0000000..e21a977 --- /dev/null +++ b/pkg/filemonitor/filemonitor.go @@ -0,0 +1,430 @@ +package filemonitor + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" +) + +// FileMonitor manages multiple file groups +type FileMonitor struct { + groups map[string]*FileGroup // Map of file groups + watcher *fsnotify.Watcher // File system watcher + watchDirs map[string]bool // Monitored directories + blacklist []string // Global blacklist patterns + mutex sync.RWMutex // Concurrency control for groups and watchDirs + stopCh chan struct{} // Stop signal + wg sync.WaitGroup // Wait group + isRunning bool // Running state flag + stateMutex sync.RWMutex // State mutex +} + +func (fm *FileMonitor) Watcher() *fsnotify.Watcher { + return fm.watcher +} + +// NewFileMonitor creates a new file monitor +func NewFileMonitor() *FileMonitor { + return &FileMonitor{ + groups: make(map[string]*FileGroup), + watchDirs: make(map[string]bool), + blacklist: []string{}, + isRunning: false, + } +} + +// SetBlacklist sets the global directory blacklist +func (fm *FileMonitor) SetBlacklist(blacklist []string) { + fm.mutex.Lock() + defer fm.mutex.Unlock() + + fm.blacklist = make([]string, len(blacklist)) + copy(fm.blacklist, blacklist) +} + +// AddGroup adds a new file group +func (fm *FileMonitor) AddGroup(group *FileGroup) error { + if group == nil { + return errors.New("group cannot be nil") + } + + // First check if monitor is running + isRunning := fm.IsRunning() + + // Add group to monitor + fm.mutex.Lock() + // Check if ID already exists + if _, exists := fm.groups[group.ID]; exists { + fm.mutex.Unlock() + return fmt.Errorf("group with ID '%s' already exists", group.ID) + } + // Add to monitor + fm.groups[group.ID] = group + fm.mutex.Unlock() + + // If monitor is running, set up watching + if isRunning { + if err := fm.setupWatchForGroup(group); err != nil { + // Remove group on failure + fm.mutex.Lock() + delete(fm.groups, group.ID) + fm.mutex.Unlock() + return err + } + } + + return nil +} + +// CreateGroup creates and adds a new file group (convenience method) +func (fm *FileMonitor) CreateGroup(id, rootDir, pattern string, blacklist []string) (*FileGroup, error) { + // Create file group + group, err := NewFileGroup(id, rootDir, pattern, blacklist) + if err != nil { + return nil, err + } + + // Add to monitor + if err := fm.AddGroup(group); err != nil { + return nil, err + } + + return group, nil +} + +// RemoveGroup removes a file group +func (fm *FileMonitor) RemoveGroup(id string) error { + fm.mutex.Lock() + defer fm.mutex.Unlock() + + // Check if group exists + _, exists := fm.groups[id] + if !exists { + return fmt.Errorf("group with ID '%s' does not exist", id) + } + + // Remove group + delete(fm.groups, id) + // log.Info().Str("groupID", id).Msg("Removed file group") + + return nil +} + +// GetGroups returns a list of all file group IDs +func (fm *FileMonitor) GetGroups() []*FileGroup { + fm.mutex.RLock() + defer fm.mutex.RUnlock() + + groups := make([]*FileGroup, 0, len(fm.groups)) + for _, group := range fm.groups { + groups = append(groups, group) + } + + return groups +} + +// GetGroup returns the specified file group +func (fm *FileMonitor) GetGroup(id string) (*FileGroup, bool) { + fm.mutex.RLock() + defer fm.mutex.RUnlock() + + group, exists := fm.groups[id] + return group, exists +} + +// Start starts the file monitor +func (fm *FileMonitor) Start() error { + // Check if already running + fm.stateMutex.Lock() + if fm.isRunning { + fm.stateMutex.Unlock() + return errors.New("file monitor is already running") + } + + // Create new watcher + watcher, err := fsnotify.NewWatcher() + if err != nil { + fm.stateMutex.Unlock() + return fmt.Errorf("failed to create watcher: %w", err) + } + fm.watcher = watcher + + // Reset stop channel + fm.stopCh = make(chan struct{}) + + // Get groups to monitor (without holding the state lock) + fm.mutex.RLock() + groups := make([]*FileGroup, 0, len(fm.groups)) + for _, group := range fm.groups { + groups = append(groups, group) + } + fm.mutex.RUnlock() + + // Reset monitored directories + fm.mutex.Lock() + fm.watchDirs = make(map[string]bool) + fm.mutex.Unlock() + + // Mark as running before setting up watches + fm.isRunning = true + fm.stateMutex.Unlock() + + // Set up monitoring for all groups (without holding any locks) + for _, group := range groups { + if err := fm.setupWatchForGroup(group); err != nil { + // Clean up resources on failure + _ = fm.watcher.Close() + + // Reset running state + fm.stateMutex.Lock() + fm.watcher = nil + fm.isRunning = false + fm.stateMutex.Unlock() + + return fmt.Errorf("failed to setup watch for group '%s': %w", group.ID, err) + } + } + + // Start watch loop + fm.wg.Add(1) + go fm.watchLoop() + + // log.Info().Msg("File monitor started") + return nil +} + +// Stop stops the file monitor +func (fm *FileMonitor) Stop() error { + // Check if already stopped + fm.stateMutex.Lock() + if !fm.isRunning { + fm.stateMutex.Unlock() + return errors.New("file monitor is not running") + } + + // Get watcher reference before changing state + watcher := fm.watcher + + // Send stop signal + close(fm.stopCh) + + // Mark as not running + fm.isRunning = false + fm.stateMutex.Unlock() + + // Wait for all goroutines to exit + fm.wg.Wait() + + // Close watcher + if watcher != nil { + if err := watcher.Close(); err != nil { + return fmt.Errorf("failed to close watcher: %w", err) + } + + fm.stateMutex.Lock() + fm.watcher = nil + fm.stateMutex.Unlock() + } + + // log.Info().Msg("File monitor stopped") + return nil +} + +// IsRunning returns whether the file monitor is running +func (fm *FileMonitor) IsRunning() bool { + fm.stateMutex.RLock() + defer fm.stateMutex.RUnlock() + return fm.isRunning +} + +// addWatchDir adds a directory to monitoring +func (fm *FileMonitor) addWatchDir(dirPath string) error { + // Check global blacklist first + fm.mutex.RLock() + for _, pattern := range fm.blacklist { + if strings.Contains(dirPath, pattern) { + fm.mutex.RUnlock() + log.Debug().Str("dir", dirPath).Msg("Skipping blacklisted directory") + return nil + } + } + fm.mutex.RUnlock() + + fm.mutex.Lock() + defer fm.mutex.Unlock() + + // Check if directory is already being monitored + if _, watched := fm.watchDirs[dirPath]; watched { + return nil // Already monitored, no need to add again + } + + // Add to monitoring + if err := fm.watcher.Add(dirPath); err != nil { + return fmt.Errorf("failed to watch directory '%s': %w", dirPath, err) + } + + fm.watchDirs[dirPath] = true + // log.Debug().Str("dir", dirPath).Msg("Added watch for directory") + return nil +} + +// setupWatchForGroup sets up monitoring for a file group +func (fm *FileMonitor) setupWatchForGroup(group *FileGroup) error { + // Check if file monitor is running + if !fm.IsRunning() { + return errors.New("file monitor is not running") + } + + // Find directories containing matching files + matchingDirs, err := group.ListMatchingDirectories() + if err != nil { + return fmt.Errorf("failed to list matching directories: %w", err) + } + + // Always watch the root directory to catch new files + rootDir := filepath.Clean(group.RootDir) + if err := fm.addWatchDir(rootDir); err != nil { + return err + } + + // Watch directories containing matching files + for dir := range matchingDirs { + if err := fm.addWatchDir(dir); err != nil { + return err + } + } + + return nil +} + +// RefreshWatches updates the watched directories based on current matching files +func (fm *FileMonitor) RefreshWatches() error { + // Check if file monitor is running + if !fm.IsRunning() { + return errors.New("file monitor is not running") + } + + // Get groups to refresh + fm.mutex.RLock() + groups := make([]*FileGroup, 0, len(fm.groups)) + for _, group := range fm.groups { + groups = append(groups, group) + } + fm.mutex.RUnlock() + + // Reset watched directories + fm.mutex.Lock() + oldWatchDirs := fm.watchDirs + fm.watchDirs = make(map[string]bool) + fm.mutex.Unlock() + + // Setup watches for each group + for _, group := range groups { + if err := fm.setupWatchForGroup(group); err != nil { + return fmt.Errorf("failed to refresh watches for group '%s': %w", group.ID, err) + } + } + + // Remove watches for directories no longer needed + for dir := range oldWatchDirs { + fm.mutex.RLock() + _, stillWatched := fm.watchDirs[dir] + fm.mutex.RUnlock() + + if !stillWatched && fm.watcher != nil { + _ = fm.watcher.Remove(dir) + log.Debug().Str("dir", dir).Msg("Removed watch for directory") + } + } + + return nil +} + +// watchLoop monitors for file system events +func (fm *FileMonitor) watchLoop() { + defer fm.wg.Done() + + for { + select { + case <-fm.stopCh: + return + + case event, ok := <-fm.watcher.Events: + if !ok { + // Channel closed, exit loop + return + } + + // Handle directory creation events to add new watches + info, err := os.Stat(event.Name) + if err == nil && info.IsDir() && event.Op&(fsnotify.Create|fsnotify.Rename) != 0 { + // Add new directory to monitoring + if err := fm.addWatchDir(event.Name); err != nil { + log.Error(). + Str("dir", event.Name). + Err(err). + Msg("Error watching new directory") + } + continue + } + + // For file creation/modification, check if we need to watch its directory + if event.Op&(fsnotify.Create|fsnotify.Write) != 0 { + // Check if this file matches any group + shouldWatch := false + + fm.mutex.RLock() + for _, group := range fm.groups { + if group.Match(event.Name) { + shouldWatch = true + break + } + } + fm.mutex.RUnlock() + + // If file matches, ensure its directory is watched + if shouldWatch { + dir := filepath.Dir(event.Name) + if err := fm.addWatchDir(dir); err != nil { + log.Error(). + Str("dir", dir). + Err(err). + Msg("Error watching directory of matching file") + } + } + } + + // Forward event to all groups + fm.forwardEventToGroups(event) + + case err, ok := <-fm.watcher.Errors: + if !ok { + // Channel closed, exit loop + return + } + log.Error().Err(err).Msg("Watcher error") + } + } +} + +// forwardEventToGroups forwards file events to matching groups +func (fm *FileMonitor) forwardEventToGroups(event fsnotify.Event) { + // Get a copy of groups to avoid holding lock during processing + fm.mutex.RLock() + groupsCopy := make([]*FileGroup, 0, len(fm.groups)) + for _, group := range fm.groups { + groupsCopy = append(groupsCopy, group) + } + fm.mutex.RUnlock() + + // Forward to all groups - each group will check if the event is relevant + for _, group := range groupsCopy { + group.HandleEvent(event) + } +}