auto decrypt (#44)

This commit is contained in:
Sarv
2025-04-16 01:02:29 +08:00
committed by GitHub
parent f2aa923e99
commit 25d0b394e2
20 changed files with 2903 additions and 712 deletions

View File

@@ -2,14 +2,17 @@ package chatlog
import (
"fmt"
"path/filepath"
"runtime"
"time"
"github.com/sjzar/chatlog/internal/chatlog/ctx"
"github.com/sjzar/chatlog/internal/ui/footer"
"github.com/sjzar/chatlog/internal/ui/form"
"github.com/sjzar/chatlog/internal/ui/help"
"github.com/sjzar/chatlog/internal/ui/infobar"
"github.com/sjzar/chatlog/internal/ui/menu"
"github.com/sjzar/chatlog/internal/wechat"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
@@ -54,6 +57,8 @@ func NewApp(ctx *ctx.Context, m *Manager) *App {
app.initMenu()
app.updateMenuItemsState()
return app
}
@@ -91,6 +96,33 @@ func (a *App) Stop() {
a.Application.Stop()
}
func (a *App) updateMenuItemsState() {
// 查找并更新自动解密菜单项
for _, item := range a.menu.GetItems() {
// 更新自动解密菜单项
if item.Index == 5 {
if a.ctx.AutoDecrypt {
item.Name = "停止自动解密"
item.Description = "停止监控数据目录更新,不再自动解密新增数据"
} else {
item.Name = "开启自动解密"
item.Description = "监控数据目录更新,自动解密新增数据"
}
}
// 更新HTTP服务菜单项
if item.Index == 4 {
if a.ctx.HTTPEnabled {
item.Name = "停止 HTTP 服务"
item.Description = "停止本地 HTTP & MCP 服务器"
} else {
item.Name = "启动 HTTP 服务"
item.Description = "启动本地 HTTP & MCP 服务器"
}
}
}
}
func (a *App) switchTab(step int) {
index := (a.activeTab + step) % a.tabCount
if index < 0 {
@@ -109,17 +141,29 @@ func (a *App) refresh() {
case <-a.stopRefresh:
return
case <-tick.C:
if a.ctx.AutoDecrypt || a.ctx.HTTPEnabled {
a.m.RefreshSession()
}
a.infoBar.UpdateAccount(a.ctx.Account)
a.infoBar.UpdateBasicInfo(a.ctx.PID, a.ctx.FullVersion, a.ctx.ExePath)
a.infoBar.UpdateStatus(a.ctx.Status)
a.infoBar.UpdateDataKey(a.ctx.DataKey)
a.infoBar.UpdatePlatform(a.ctx.Platform)
a.infoBar.UpdateDataUsageDir(a.ctx.DataUsage, a.ctx.DataDir)
a.infoBar.UpdateWorkUsageDir(a.ctx.WorkUsage, a.ctx.WorkDir)
if a.ctx.LastSession.Unix() > 1000000000 {
a.infoBar.UpdateSession(a.ctx.LastSession.Format("2006-01-02 15:04:05"))
}
if a.ctx.HTTPEnabled {
a.infoBar.UpdateHTTPServer(fmt.Sprintf("[green][已启动][white] [%s]", a.ctx.HTTPAddr))
} else {
a.infoBar.UpdateHTTPServer("[未启动]")
}
if a.ctx.AutoDecrypt {
a.infoBar.UpdateAutoDecrypt("[green][已开启][white]")
} else {
a.infoBar.UpdateAutoDecrypt("[未开启]")
}
a.Draw()
}
@@ -257,11 +301,11 @@ func (a *App) initMenu() {
} else {
// 启动成功
modal.SetText("已启动 HTTP 服务")
// 更改菜单项名称
i.Name = "停止 HTTP 服务"
i.Description = "停止本地 HTTP & MCP 服务器"
}
// 更改菜单项名称
a.updateMenuItemsState()
// 添加确认按钮
modal.AddButtons([]string{"OK"})
modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) {
@@ -288,11 +332,89 @@ func (a *App) initMenu() {
} else {
// 停止成功
modal.SetText("已停止 HTTP 服务")
// 更改菜单项名称
i.Name = "启动 HTTP 服务"
i.Description = "启动本地 HTTP & MCP 服务器"
}
// 更改菜单项名称
a.updateMenuItemsState()
// 添加确认按钮
modal.AddButtons([]string{"OK"})
modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) {
a.mainPages.RemovePage("modal")
})
a.SetFocus(modal)
})
}()
}
},
}
autoDecrypt := &menu.Item{
Index: 5,
Name: "开启自动解密",
Description: "自动解密新增的数据文件",
Selected: func(i *menu.Item) {
modal := tview.NewModal()
// 根据当前自动解密状态执行不同操作
if !a.ctx.AutoDecrypt {
// 自动解密未开启,开启自动解密
modal.SetText("正在开启自动解密...")
a.mainPages.AddPage("modal", modal, true, true)
a.SetFocus(modal)
// 在后台开启自动解密
go func() {
err := a.m.StartAutoDecrypt()
// 在主线程中更新UI
a.QueueUpdateDraw(func() {
if err != nil {
// 开启失败
modal.SetText("开启自动解密失败: " + err.Error())
} else {
// 开启成功
if a.ctx.Version == 3 {
modal.SetText("已开启自动解密\n3.x版本数据文件更新不及时有低延迟需求请使用4.0版本")
} else {
modal.SetText("已开启自动解密")
}
}
// 更改菜单项名称
a.updateMenuItemsState()
// 添加确认按钮
modal.AddButtons([]string{"OK"})
modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) {
a.mainPages.RemovePage("modal")
})
a.SetFocus(modal)
})
}()
} else {
// 自动解密已开启,停止自动解密
modal.SetText("正在停止自动解密...")
a.mainPages.AddPage("modal", modal, true, true)
a.SetFocus(modal)
// 在后台停止自动解密
go func() {
err := a.m.StopAutoDecrypt()
// 在主线程中更新UI
a.QueueUpdateDraw(func() {
if err != nil {
// 停止失败
modal.SetText("停止自动解密失败: " + err.Error())
} else {
// 停止成功
modal.SetText("已停止自动解密")
}
// 更改菜单项名称
a.updateMenuItemsState()
// 添加确认按钮
modal.AddButtons([]string{"OK"})
modal.SetDoneFunc(func(buttonIndex int, buttonLabel string) {
@@ -306,19 +428,28 @@ func (a *App) initMenu() {
}
setting := &menu.Item{
Index: 5,
Index: 6,
Name: "设置",
Description: "设置应用程序选项",
Selected: a.settingSelected,
}
a.menu.AddItem(setting)
selectAccount := &menu.Item{
Index: 7,
Name: "切换账号",
Description: "切换当前操作的账号,可以选择进程或历史账号",
Selected: a.selectAccountSelected,
}
a.menu.AddItem(getDataKey)
a.menu.AddItem(decryptData)
a.menu.AddItem(httpServer)
a.menu.AddItem(autoDecrypt)
a.menu.AddItem(setting)
a.menu.AddItem(selectAccount)
a.menu.AddItem(&menu.Item{
Index: 6,
Index: 8,
Name: "退出",
Description: "退出程序",
Selected: func(i *menu.Item) {
@@ -347,6 +478,16 @@ func (a *App) settingSelected(i *menu.Item) {
description: "配置数据解密后的存储目录",
action: a.settingWorkDir,
},
{
name: "设置数据密钥",
description: "配置数据解密密钥",
action: a.settingDataKey,
},
{
name: "设置数据目录",
description: "配置微信数据文件所在目录",
action: a.settingDataDir,
},
}
subMenu := menu.NewSubMenu("设置")
@@ -370,43 +511,279 @@ func (a *App) settingSelected(i *menu.Item) {
// settingHTTPPort 设置 HTTP 端口
func (a *App) settingHTTPPort() {
// 实现端口设置逻辑
// 这里可以使用 tview.InputField 让用户输入端口
form := tview.NewForm().
AddInputField("地址", a.ctx.HTTPAddr, 20, nil, func(text string) {
a.m.SetHTTPAddr(text)
}).
AddButton("保存", func() {
a.mainPages.RemovePage("submenu2")
a.showInfo("HTTP 地址已设置为 " + a.ctx.HTTPAddr)
}).
AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
form.SetBorder(true).SetTitle("设置 HTTP 地址")
// 使用我们的自定义表单组件
formView := form.NewForm("设置 HTTP 地址")
a.mainPages.AddPage("submenu2", form, true, true)
a.SetFocus(form)
// 临时存储用户输入的值
tempHTTPAddr := a.ctx.HTTPAddr
// 添加输入字段 - 不再直接设置HTTP地址而是更新临时变量
formView.AddInputField("地址", tempHTTPAddr, 0, nil, func(text string) {
tempHTTPAddr = text // 只更新临时变量
})
// 添加按钮 - 点击保存时才设置HTTP地址
formView.AddButton("保存", func() {
a.m.SetHTTPAddr(tempHTTPAddr) // 在这里设置HTTP地址
a.mainPages.RemovePage("submenu2")
a.showInfo("HTTP 地址已设置为 " + a.ctx.HTTPAddr)
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
// settingWorkDir 设置工作目录
func (a *App) settingWorkDir() {
// 实现工作目录设置逻辑
form := tview.NewForm().
AddInputField("工作目录", a.ctx.WorkDir, 40, nil, func(text string) {
a.ctx.SetWorkDir(text)
}).
AddButton("保存", func() {
a.mainPages.RemovePage("submenu2")
a.showInfo("工作目录已设置为 " + a.ctx.WorkDir)
}).
AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
form.SetBorder(true).SetTitle("设置工作目录")
// 使用我们的自定义表单组件
formView := form.NewForm("设置工作目录")
a.mainPages.AddPage("submenu2", form, true, true)
a.SetFocus(form)
// 临时存储用户输入的值
tempWorkDir := a.ctx.WorkDir
// 添加输入字段 - 不再直接设置工作目录,而是更新临时变量
formView.AddInputField("工作目录", tempWorkDir, 0, nil, func(text string) {
tempWorkDir = text // 只更新临时变量
})
// 添加按钮 - 点击保存时才设置工作目录
formView.AddButton("保存", func() {
a.ctx.SetWorkDir(tempWorkDir) // 在这里设置工作目录
a.mainPages.RemovePage("submenu2")
a.showInfo("工作目录已设置为 " + a.ctx.WorkDir)
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
// settingDataKey 设置数据密钥
func (a *App) settingDataKey() {
// 使用我们的自定义表单组件
formView := form.NewForm("设置数据密钥")
// 临时存储用户输入的值
tempDataKey := a.ctx.DataKey
// 添加输入字段 - 不直接设置数据密钥,而是更新临时变量
formView.AddInputField("数据密钥", tempDataKey, 0, nil, func(text string) {
tempDataKey = text // 只更新临时变量
})
// 添加按钮 - 点击保存时才设置数据密钥
formView.AddButton("保存", func() {
a.ctx.DataKey = tempDataKey // 设置数据密钥
a.mainPages.RemovePage("submenu2")
a.showInfo("数据密钥已设置")
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
// settingDataDir 设置数据目录
func (a *App) settingDataDir() {
// 使用我们的自定义表单组件
formView := form.NewForm("设置数据目录")
// 临时存储用户输入的值
tempDataDir := a.ctx.DataDir
// 添加输入字段 - 不直接设置数据目录,而是更新临时变量
formView.AddInputField("数据目录", tempDataDir, 0, nil, func(text string) {
tempDataDir = text // 只更新临时变量
})
// 添加按钮 - 点击保存时才设置数据目录
formView.AddButton("保存", func() {
a.ctx.DataDir = tempDataDir // 设置数据目录
a.mainPages.RemovePage("submenu2")
a.showInfo("数据目录已设置为 " + a.ctx.DataDir)
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
// selectAccountSelected 处理切换账号菜单项的选择事件
func (a *App) selectAccountSelected(i *menu.Item) {
// 创建子菜单
subMenu := menu.NewSubMenu("切换账号")
// 添加微信进程
instances := a.m.wechat.GetWeChatInstances()
if len(instances) > 0 {
// 添加实例标题
subMenu.AddItem(&menu.Item{
Index: 0,
Name: "--- 微信进程 ---",
Description: "",
Hidden: false,
Selected: nil,
})
// 添加实例列表
for idx, instance := range instances {
// 创建一个实例描述
description := fmt.Sprintf("版本: %s 目录: %s", instance.FullVersion, instance.DataDir)
// 标记当前选中的实例
name := fmt.Sprintf("%s [%d]", instance.Name, instance.PID)
if a.ctx.Current != nil && a.ctx.Current.PID == instance.PID {
name = name + " [当前]"
}
// 创建菜单项
instanceItem := &menu.Item{
Index: idx + 1,
Name: name,
Description: description,
Hidden: false,
Selected: func(instance *wechat.Account) func(*menu.Item) {
return func(*menu.Item) {
// 如果是当前账号,则无需切换
if a.ctx.Current != nil && a.ctx.Current.PID == instance.PID {
a.mainPages.RemovePage("submenu")
a.showInfo("已经是当前账号")
return
}
// 显示切换中的模态框
modal := tview.NewModal().SetText("正在切换账号...")
a.mainPages.AddPage("modal", modal, true, true)
a.SetFocus(modal)
// 在后台执行切换操作
go func() {
err := a.m.Switch(instance, "")
// 在主线程中更新UI
a.QueueUpdateDraw(func() {
a.mainPages.RemovePage("modal")
a.mainPages.RemovePage("submenu")
if err != nil {
// 切换失败
a.showError(fmt.Errorf("切换账号失败: %v", err))
} else {
// 切换成功
a.showInfo("切换账号成功")
// 更新菜单状态
a.updateMenuItemsState()
}
})
}()
}
}(instance),
}
subMenu.AddItem(instanceItem)
}
}
// 添加历史账号
if len(a.ctx.History) > 0 {
// 添加历史账号标题
subMenu.AddItem(&menu.Item{
Index: 100,
Name: "--- 历史账号 ---",
Description: "",
Hidden: false,
Selected: nil,
})
// 添加历史账号列表
idx := 101
for account, hist := range a.ctx.History {
// 创建一个账号描述
description := fmt.Sprintf("版本: %s 目录: %s", hist.FullVersion, hist.DataDir)
// 标记当前选中的账号
name := account
if name == "" {
name = filepath.Base(hist.DataDir)
}
if a.ctx.DataDir == hist.DataDir {
name = name + " [当前]"
}
// 创建菜单项
histItem := &menu.Item{
Index: idx,
Name: name,
Description: description,
Hidden: false,
Selected: func(account string) func(*menu.Item) {
return func(*menu.Item) {
// 如果是当前账号,则无需切换
if a.ctx.Current != nil && a.ctx.DataDir == a.ctx.History[account].DataDir {
a.mainPages.RemovePage("submenu")
a.showInfo("已经是当前账号")
return
}
// 显示切换中的模态框
modal := tview.NewModal().SetText("正在切换账号...")
a.mainPages.AddPage("modal", modal, true, true)
a.SetFocus(modal)
// 在后台执行切换操作
go func() {
err := a.m.Switch(nil, account)
// 在主线程中更新UI
a.QueueUpdateDraw(func() {
a.mainPages.RemovePage("modal")
a.mainPages.RemovePage("submenu")
if err != nil {
// 切换失败
a.showError(fmt.Errorf("切换账号失败: %v", err))
} else {
// 切换成功
a.showInfo("切换账号成功")
// 更新菜单状态
a.updateMenuItemsState()
}
})
}()
}
}(account),
}
idx++
subMenu.AddItem(histItem)
}
}
// 如果没有账号可选择
if len(a.ctx.History) == 0 && len(instances) == 0 {
subMenu.AddItem(&menu.Item{
Index: 1,
Name: "无可用账号",
Description: "未检测到微信进程或历史账号",
Hidden: false,
Selected: nil,
})
}
// 显示子菜单
a.mainPages.AddPage("submenu", subMenu, true, true)
a.SetFocus(subMenu)
}
// showModal 显示一个模态对话框

View File

@@ -2,6 +2,7 @@ package ctx
import (
"sync"
"time"
"github.com/sjzar/chatlog/internal/chatlog/conf"
"github.com/sjzar/chatlog/internal/wechat"
@@ -33,6 +34,10 @@ type Context struct {
HTTPEnabled bool
HTTPAddr string
// 自动解密
AutoDecrypt bool
LastSession time.Time
// 当前选中的微信实例
Current *wechat.Account
PID int
@@ -63,6 +68,10 @@ func (c *Context) loadConfig() {
func (c *Context) SwitchHistory(account string) {
c.mu.Lock()
defer c.mu.Unlock()
c.Current = nil
c.PID = 0
c.ExePath = ""
c.Status = ""
history, ok := c.History[account]
if ok {
c.Account = history.Account
@@ -153,6 +162,13 @@ func (c *Context) SetDataDir(dir string) {
c.Refresh()
}
func (c *Context) SetAutoDecrypt(enabled bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.AutoDecrypt = enabled
c.UpdateConfig()
}
// 更新配置
func (c *Context) UpdateConfig() {
pconf := conf.ProcessConfig{

View File

@@ -84,11 +84,12 @@ func (s *Service) Stop() error {
}
// 使用超时上下文优雅关闭
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
return errors.HTTPShutDown(err)
log.Debug().Err(err).Msg("Failed to shutdown HTTP server")
return nil
}
log.Info().Msg("HTTP server stopped")

View File

@@ -6,12 +6,14 @@ import (
"path/filepath"
"strings"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/chatlog/conf"
"github.com/sjzar/chatlog/internal/chatlog/ctx"
"github.com/sjzar/chatlog/internal/chatlog/database"
"github.com/sjzar/chatlog/internal/chatlog/http"
"github.com/sjzar/chatlog/internal/chatlog/mcp"
"github.com/sjzar/chatlog/internal/chatlog/wechat"
iwechat "github.com/sjzar/chatlog/internal/wechat"
"github.com/sjzar/chatlog/pkg/util"
"github.com/sjzar/chatlog/pkg/util/dat2img"
)
@@ -79,6 +81,33 @@ func (m *Manager) Run() error {
return nil
}
func (m *Manager) Switch(info *iwechat.Account, history string) error {
if m.ctx.AutoDecrypt {
if err := m.StopAutoDecrypt(); err != nil {
return err
}
}
if m.ctx.HTTPEnabled {
if err := m.stopService(); err != nil {
return err
}
}
if info != nil {
m.ctx.SwitchCurrent(info)
} else {
m.ctx.SwitchHistory(history)
}
if m.ctx.HTTPEnabled {
// 启动HTTP服务
if err := m.StartService(); err != nil {
log.Info().Err(err).Msg("启动服务失败")
m.StopService()
}
}
return nil
}
func (m *Manager) StartService() error {
// 按依赖顺序启动服务
@@ -109,6 +138,17 @@ func (m *Manager) StartService() error {
}
func (m *Manager) StopService() error {
if err := m.stopService(); err != nil {
return err
}
// 更新状态
m.ctx.SetHTTPEnabled(false)
return nil
}
func (m *Manager) stopService() error {
// 按依赖的反序停止服务
var errs []error
@@ -124,9 +164,6 @@ func (m *Manager) StopService() error {
errs = append(errs, err)
}
// 更新状态
m.ctx.SetHTTPEnabled(false)
// 如果有错误,返回第一个错误
if len(errs) > 0 {
return errs[0]
@@ -138,7 +175,7 @@ func (m *Manager) StopService() error {
func (m *Manager) SetHTTPAddr(text string) error {
var addr string
if util.IsNumeric(text) {
addr = fmt.Sprintf("0.0.0.0:%s", text)
addr = fmt.Sprintf("127.0.0.1:%s", text)
} else if strings.HasPrefix(text, "http://") {
addr = strings.TrimPrefix(text, "http://")
} else if strings.HasPrefix(text, "https://") {
@@ -175,7 +212,7 @@ func (m *Manager) DecryptDBFiles() error {
m.ctx.WorkDir = util.DefaultWorkDir(m.ctx.Account)
}
if err := m.wechat.DecryptDBFiles(m.ctx.DataDir, m.ctx.WorkDir, m.ctx.DataKey, m.ctx.Platform, m.ctx.Version); err != nil {
if err := m.wechat.DecryptDBFiles(); err != nil {
return err
}
m.ctx.Refresh()
@@ -183,6 +220,48 @@ func (m *Manager) DecryptDBFiles() error {
return nil
}
func (m *Manager) StartAutoDecrypt() error {
if m.ctx.DataKey == "" || m.ctx.DataDir == "" {
return fmt.Errorf("请先获取密钥")
}
if m.ctx.WorkDir == "" {
return fmt.Errorf("请先执行解密数据")
}
if err := m.wechat.StartAutoDecrypt(); err != nil {
return err
}
m.ctx.SetAutoDecrypt(true)
return nil
}
func (m *Manager) StopAutoDecrypt() error {
if err := m.wechat.StopAutoDecrypt(); err != nil {
return err
}
m.ctx.SetAutoDecrypt(false)
return nil
}
func (m *Manager) RefreshSession() error {
if m.db.GetDB() == nil {
if err := m.db.Start(); err != nil {
return err
}
}
resp, err := m.db.GetSessions("", 1, 0)
if err != nil {
return err
}
if len(resp.Items) == 0 {
return nil
}
m.ctx.LastSession = resp.Items[0].NTime
return nil
}
func (m *Manager) CommandKey(pid int) (string, error) {
instances := m.wechat.GetWeChatInstances()
if len(instances) == 0 {
@@ -216,8 +295,12 @@ func (m *Manager) CommandDecrypt(dataDir string, workDir string, key string, pla
if workDir == "" {
workDir = util.DefaultWorkDir(filepath.Base(filepath.Dir(dataDir)))
}
if err := m.wechat.DecryptDBFiles(dataDir, workDir, key, platform, version); err != nil {
m.ctx.DataDir = dataDir
m.ctx.WorkDir = workDir
m.ctx.DataKey = key
m.ctx.Platform = platform
m.ctx.Version = version
if err := m.wechat.DecryptDBFiles(); err != nil {
return err
}

View File

@@ -5,24 +5,38 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/chatlog/ctx"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/wechat"
"github.com/sjzar/chatlog/internal/wechat/decrypt"
"github.com/sjzar/chatlog/pkg/filemonitor"
"github.com/sjzar/chatlog/pkg/util"
)
"github.com/rs/zerolog/log"
var (
DebounceTime = 1 * time.Second
MaxWaitTime = 10 * time.Second
)
type Service struct {
ctx *ctx.Context
ctx *ctx.Context
lastEvents map[string]time.Time
pendingActions map[string]bool
mutex sync.Mutex
fm *filemonitor.FileMonitor
}
func NewService(ctx *ctx.Context) *Service {
return &Service{
ctx: ctx,
ctx: ctx,
lastEvents: make(map[string]time.Time),
pendingActions: make(map[string]bool),
}
}
@@ -46,90 +60,128 @@ func (s *Service) GetDataKey(info *wechat.Account) (string, error) {
return key, nil
}
// FindDBFiles finds all .db files in the specified directory
func (s *Service) FindDBFiles(rootDir string, recursive bool) ([]string, error) {
// Check if directory exists
info, err := os.Stat(rootDir)
func (s *Service) StartAutoDecrypt() error {
dbGroup, err := filemonitor.NewFileGroup("wechat", s.ctx.DataDir, `.*\.db$`, []string{"fts"})
if err != nil {
return nil, fmt.Errorf("cannot access directory %s: %w", rootDir, err)
return err
}
dbGroup.AddCallback(s.DecryptFileCallback)
if !info.IsDir() {
return nil, fmt.Errorf("%s is not a directory", rootDir)
s.fm = filemonitor.NewFileMonitor()
s.fm.AddGroup(dbGroup)
if err := s.fm.Start(); err != nil {
log.Debug().Err(err).Msg("failed to start file monitor")
return err
}
return nil
}
var dbFiles []string
// Define walk function
walkFunc := func(path string, info os.FileInfo, err error) error {
if err != nil {
// If a file or directory can't be accessed, log the error but continue
log.Err(err).Msgf("Warning: Cannot access %s", path)
return nil
}
// If it's a directory and not the root directory, and we're not recursively searching, skip it
if info.IsDir() && path != rootDir && !recursive {
return filepath.SkipDir
}
// Check if file extension is .db
if !info.IsDir() && strings.ToLower(filepath.Ext(path)) == ".db" {
dbFiles = append(dbFiles, path)
func (s *Service) StopAutoDecrypt() error {
if s.fm != nil {
if err := s.fm.Stop(); err != nil {
return err
}
}
s.fm = nil
return nil
}
func (s *Service) DecryptFileCallback(event fsnotify.Event) error {
if event.Op.Has(fsnotify.Chmod) || !event.Op.Has(fsnotify.Write) {
return nil
}
// Start traversal
err = filepath.Walk(rootDir, walkFunc)
if err != nil {
return nil, fmt.Errorf("error traversing directory: %w", err)
s.mutex.Lock()
s.lastEvents[event.Name] = time.Now()
if !s.pendingActions[event.Name] {
s.pendingActions[event.Name] = true
s.mutex.Unlock()
go s.waitAndProcess(event.Name)
} else {
s.mutex.Unlock()
}
if len(dbFiles) == 0 {
return nil, fmt.Errorf("no .db files found")
}
return dbFiles, nil
return nil
}
func (s *Service) DecryptDBFiles(dataDir string, workDir string, key string, platform string, version int) error {
func (s *Service) waitAndProcess(dbFile string) {
start := time.Now()
for {
time.Sleep(DebounceTime)
ctx := context.Background()
s.mutex.Lock()
lastEventTime := s.lastEvents[dbFile]
elapsed := time.Since(lastEventTime)
totalElapsed := time.Since(start)
dbfiles, err := s.FindDBFiles(dataDir, true)
if elapsed >= DebounceTime || totalElapsed >= MaxWaitTime {
s.pendingActions[dbFile] = false
s.mutex.Unlock()
log.Debug().Msgf("Processing file: %s", dbFile)
s.DecryptDBFile(dbFile)
return
}
s.mutex.Unlock()
}
}
func (s *Service) DecryptDBFile(dbFile string) error {
decryptor, err := decrypt.NewDecryptor(s.ctx.Platform, s.ctx.Version)
if err != nil {
return err
}
decryptor, err := decrypt.NewDecryptor(platform, version)
if err != nil {
output := filepath.Join(s.ctx.WorkDir, dbFile[len(s.ctx.DataDir):])
if err := util.PrepareDir(filepath.Dir(output)); err != nil {
return err
}
for _, dbfile := range dbfiles {
output := filepath.Join(workDir, dbfile[len(dataDir):])
if err := util.PrepareDir(filepath.Dir(output)); err != nil {
return err
outputTemp := output + ".tmp"
outputFile, err := os.Create(outputTemp)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer func() {
outputFile.Close()
if err := os.Rename(outputTemp, output); err != nil {
log.Debug().Err(err).Msgf("failed to rename %s to %s", outputTemp, output)
}
}()
outputFile, err := os.Create(output)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer outputFile.Close()
if err := decryptor.Decrypt(ctx, dbfile, key, outputFile); err != nil {
log.Err(err).Msgf("failed to decrypt %s", dbfile)
if err == errors.ErrAlreadyDecrypted {
if data, err := os.ReadFile(dbfile); err == nil {
outputFile.Write(data)
}
continue
if err := decryptor.Decrypt(context.Background(), dbFile, s.ctx.DataKey, outputFile); err != nil {
if err == errors.ErrAlreadyDecrypted {
if data, err := os.ReadFile(dbFile); err == nil {
outputFile.Write(data)
}
return nil
}
log.Err(err).Msgf("failed to decrypt %s", dbFile)
return err
}
log.Debug().Msgf("Decrypted %s to %s", dbFile, output)
return nil
}
func (s *Service) DecryptDBFiles() error {
dbGroup, err := filemonitor.NewFileGroup("wechat", s.ctx.DataDir, `.*\.db$`, []string{"fts"})
if err != nil {
return err
}
dbFiles, err := dbGroup.List()
if err != nil {
return err
}
for _, dbFile := range dbFiles {
if err := s.DecryptDBFile(dbFile); err != nil {
log.Debug().Msgf("DecryptDBFile %s failed: %v", dbFile, err)
continue
// return err
}
}

View File

@@ -60,3 +60,7 @@ func ContactNotFound(key string) *Error {
func InitCacheFailed(cause error) *Error {
return New(cause, http.StatusInternalServerError, "init cache failed").WithStack()
}
func FileGroupNotFound(name string) *Error {
return Newf(nil, http.StatusNotFound, "file group not found: %s", name).WithStack()
}

258
internal/ui/form/form.go Normal file
View File

@@ -0,0 +1,258 @@
package form
import (
"fmt"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
"github.com/sjzar/chatlog/internal/ui/style"
)
const (
// DialogPadding dialog inner padding.
DialogPadding = 3
// DialogHelpHeight dialog help text height.
DialogHelpHeight = 1
// DialogMinWidth dialog min width.
DialogMinWidth = 40
// FormHeightOffset form height offset for border.
FormHeightOffset = 3
// 额外的宽度补偿,类似于 submenu 的 cmdWidthOffset
formWidthOffset = 10
)
// Form is a modal form component with a title, form fields, and help text.
type Form struct {
*tview.Box
title string
layout *tview.Flex
form *tview.Form
helpText *tview.TextView
width int
height int
cancelHandler func()
fields []formField // 存储字段信息以便重新计算宽度
}
// formField 存储表单字段的信息
type formField struct {
label string
value string
fieldWidth int
}
// NewForm creates a new form with the given title.
func NewForm(title string) *Form {
f := &Form{
Box: tview.NewBox(),
title: title,
layout: tview.NewFlex().SetDirection(tview.FlexRow),
form: tview.NewForm(),
fields: make([]formField, 0),
}
// 设置表单样式
f.form.SetBorderPadding(1, 1, 1, 1)
f.form.SetBackgroundColor(style.DialogBgColor)
f.form.SetFieldBackgroundColor(style.BgColor)
f.form.SetFieldTextColor(style.FgColor)
f.form.SetButtonBackgroundColor(style.ButtonBgColor)
f.form.SetButtonTextColor(style.FgColor)
f.form.SetLabelColor(style.DialogFgColor)
f.form.SetButtonsAlign(tview.AlignCenter)
// 创建帮助文本
f.helpText = tview.NewTextView()
f.helpText.SetDynamicColors(true)
f.helpText.SetTextAlign(tview.AlignCenter)
f.helpText.SetTextColor(style.DialogFgColor)
f.helpText.SetBackgroundColor(style.DialogBgColor)
fmt.Fprintf(f.helpText,
"[%s::b]Tab[%s::b]: 导航 [%s::b]Enter[%s::b]: 选择 [%s::b]ESC[%s::b]: 返回",
style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor),
style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor),
style.GetColorHex(style.MenuBgColor), style.GetColorHex(style.PageHeaderFgColor),
)
// 创建布局
formLayout := tview.NewFlex().SetDirection(tview.FlexColumn)
formLayout.AddItem(EmptyBoxSpace(style.DialogBgColor), 1, 0, false)
formLayout.AddItem(f.form, 0, 1, true)
formLayout.AddItem(EmptyBoxSpace(style.DialogBgColor), 1, 0, false)
// 设置主布局
f.layout.SetTitle(fmt.Sprintf("[::b]%s", f.title))
f.layout.SetTitleColor(style.DialogFgColor)
f.layout.SetTitleAlign(tview.AlignCenter)
f.layout.SetBorder(true)
f.layout.SetBorderColor(style.DialogBorderColor)
f.layout.SetBackgroundColor(style.DialogBgColor)
// 添加表单区域
f.layout.AddItem(formLayout, 0, 1, true)
// 添加帮助文本区域
f.layout.AddItem(f.helpText, DialogHelpHeight, 0, false)
return f
}
// AddInputField adds an input field to the form.
func (f *Form) AddInputField(label, value string, fieldWidth int, accept func(textToCheck string, lastChar rune) bool, changed func(text string)) *Form {
// 存储字段信息
f.fields = append(f.fields, formField{
label: label,
value: value,
fieldWidth: fieldWidth,
})
// 添加输入字段到表单
f.form.AddInputField(label, value, fieldWidth, accept, changed)
// 更新表单尺寸
f.recalculateSize()
return f
}
// AddButton adds a button to the form.
func (f *Form) AddButton(label string, selected func()) *Form {
f.form.AddButton(label, selected)
// 更新表单尺寸
f.recalculateSize()
return f
}
// AddCheckbox adds a checkbox to the form.
func (f *Form) AddCheckbox(label string, checked bool, changed func(checked bool)) *Form {
f.form.AddCheckbox(label, checked, changed)
// 更新表单尺寸
f.recalculateSize()
return f
}
// SetCancelFunc sets the function to be called when the form is cancelled.
func (f *Form) SetCancelFunc(handler func()) *Form {
f.cancelHandler = handler
return f
}
// recalculateSize 重新计算表单尺寸
func (f *Form) recalculateSize() {
// 计算表单项数量
itemCount := f.form.GetFormItemCount()
// 计算高度 - 每个表单项占2行按钮区域至少占2行再加上边框和帮助文本
f.height = (itemCount * 2) + 2 + FormHeightOffset + DialogHelpHeight
// 计算宽度 - 类似于 submenu 的实现
maxLabelWidth := 0
maxValueWidth := 0
// 遍历所有字段,找出最长的标签和值
for _, field := range f.fields {
if len(field.label) > maxLabelWidth {
maxLabelWidth = len(field.label)
}
// 对于值,使用字段宽度和实际值长度中的较大者
valueWidth := field.fieldWidth
if len(field.value) > valueWidth {
valueWidth = len(field.value)
}
if valueWidth > maxValueWidth {
maxValueWidth = valueWidth
}
}
// 计算总宽度,类似于 submenu 的计算方式
f.width = maxLabelWidth + maxValueWidth + formWidthOffset
// 确保宽度不小于最小值
if f.width < DialogMinWidth {
f.width = DialogMinWidth
}
}
// Draw draws the form on the screen.
func (f *Form) Draw(screen tcell.Screen) {
// 在绘制前重新计算尺寸,确保尺寸是最新的
f.recalculateSize()
// 绘制
f.Box.DrawForSubclass(screen, f)
f.layout.Draw(screen)
}
// SetRect sets the position and size of the form.
func (f *Form) SetRect(x, y, width, height int) {
// 确保尺寸是最新的
f.recalculateSize()
// 类似于 submenu 的实现
ws := (width - f.width) / 2
hs := (height - f.height) / 2
// 确保不会超出屏幕
if f.width > width {
ws = 0
f.width = width - 1
}
if f.height > height {
hs = 0
f.height = height - 1
}
// 设置表单位置
f.Box.SetRect(x+ws, y+hs, f.width, f.height)
// 获取内部矩形并设置布局
x, y, width, height = f.Box.GetInnerRect()
f.layout.SetRect(x, y, width, height)
}
// Focus is called when this primitive receives focus.
func (f *Form) Focus(delegate func(p tview.Primitive)) {
// 确保表单获得焦点
if f.form != nil {
delegate(f.form)
} else {
// 如果表单为空则让Box获得焦点
delegate(f.Box)
}
}
// HasFocus returns whether or not this primitive has focus.
func (f *Form) HasFocus() bool {
return f.form.HasFocus()
}
// InputHandler returns the handler for this primitive.
func (f *Form) InputHandler() func(event *tcell.EventKey, setFocus func(p tview.Primitive)) {
return f.WrapInputHandler(func(event *tcell.EventKey, setFocus func(p tview.Primitive)) {
// ESC键处理
if event.Key() == tcell.KeyEscape && f.cancelHandler != nil {
f.cancelHandler()
return
}
// 将事件传递给表单
if handler := f.form.InputHandler(); handler != nil {
handler(event, setFocus)
}
})
}
// EmptyBoxSpace creates an empty box with the specified background color.
func EmptyBoxSpace(bgColor tcell.Color) *tview.Box {
box := tview.NewBox()
box.SetBackgroundColor(bgColor)
box.SetBorder(false)
return box
}

View File

@@ -15,13 +15,14 @@ const (
// InfoBarViewHeight info bar height.
const (
InfoBarViewHeight = 6
InfoBarViewHeight = 7
accountRow = 0
pidRow = 1
statusRow = 2
dataUsageRow = 3
workUsageRow = 4
httpServerRow = 5
statusRow = 1
platformRow = 2
sessionRow = 3
dataUsageRow = 4
workUsageRow = 5
httpServerRow = 6
// 列索引
labelCol1 = 0 // 第一列标签
@@ -43,7 +44,7 @@ func New() *InfoBar {
table := tview.NewTable()
headerColor := style.InfoBarItemFgColor
// Account 和 Version
// Account 和 PID
table.SetCell(
accountRow,
labelCol1,
@@ -54,26 +55,11 @@ func New() *InfoBar {
table.SetCell(
accountRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Version:")),
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "PID:")),
)
table.SetCell(accountRow, valueCol2, tview.NewTableCell(""))
// PID 和 ExePath 行
table.SetCell(
pidRow,
labelCol1,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "PID:")),
)
table.SetCell(pidRow, valueCol1, tview.NewTableCell(""))
table.SetCell(
pidRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "ExePath:")),
)
table.SetCell(pidRow, valueCol2, tview.NewTableCell(""))
// Status 和 Key 行
// Status 和 ExePath 行
table.SetCell(
statusRow,
labelCol1,
@@ -84,10 +70,40 @@ func New() *InfoBar {
table.SetCell(
statusRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Data Key:")),
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "ExePath:")),
)
table.SetCell(statusRow, valueCol2, tview.NewTableCell(""))
// Platform 和 Version 行
table.SetCell(
platformRow,
labelCol1,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Platform:")),
)
table.SetCell(platformRow, valueCol1, tview.NewTableCell(""))
table.SetCell(
platformRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Version:")),
)
table.SetCell(platformRow, valueCol2, tview.NewTableCell(""))
// Session 和 Data Key 行
table.SetCell(
sessionRow,
labelCol1,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Session:")),
)
table.SetCell(sessionRow, valueCol1, tview.NewTableCell(""))
table.SetCell(
sessionRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Data Key:")),
)
table.SetCell(sessionRow, valueCol2, tview.NewTableCell(""))
// Data Usage 和 Data Dir 行
table.SetCell(
dataUsageRow,
@@ -126,6 +142,13 @@ func New() *InfoBar {
)
table.SetCell(httpServerRow, valueCol1, tview.NewTableCell(""))
table.SetCell(
httpServerRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "Auto Decrypt:")),
)
table.SetCell(httpServerRow, valueCol2, tview.NewTableCell(""))
// infobar
infoBar := &InfoBar{
Box: tview.NewBox(),
@@ -141,17 +164,25 @@ func (info *InfoBar) UpdateAccount(account string) {
}
func (info *InfoBar) UpdateBasicInfo(pid int, version string, exePath string) {
info.table.GetCell(pidRow, valueCol1).SetText(fmt.Sprintf("%d", pid))
info.table.GetCell(pidRow, valueCol2).SetText(exePath)
info.table.GetCell(accountRow, valueCol2).SetText(version)
info.table.GetCell(accountRow, valueCol2).SetText(fmt.Sprintf("%d", pid))
info.table.GetCell(statusRow, valueCol2).SetText(exePath)
info.table.GetCell(platformRow, valueCol2).SetText(version)
}
func (info *InfoBar) UpdateStatus(status string) {
info.table.GetCell(statusRow, valueCol1).SetText(status)
}
func (info *InfoBar) UpdatePlatform(text string) {
info.table.GetCell(platformRow, valueCol1).SetText(text)
}
func (info *InfoBar) UpdateSession(text string) {
info.table.GetCell(sessionRow, valueCol1).SetText(text)
}
func (info *InfoBar) UpdateDataKey(key string) {
info.table.GetCell(statusRow, valueCol2).SetText(key)
info.table.GetCell(sessionRow, valueCol2).SetText(key)
}
func (info *InfoBar) UpdateDataUsageDir(dataUsage string, dataDir string) {
@@ -169,6 +200,11 @@ func (info *InfoBar) UpdateHTTPServer(server string) {
info.table.GetCell(httpServerRow, valueCol1).SetText(server)
}
// UpdateAutoDecrypt updates Auto Decrypt value.
func (info *InfoBar) UpdateAutoDecrypt(text string) {
info.table.GetCell(httpServerRow, valueCol2).SetText(text)
}
// Draw draws this primitive onto the screen.
func (info *InfoBar) Draw(screen tcell.Screen) {
info.Box.DrawForSubclass(screen, info)

View File

@@ -3,87 +3,127 @@ package darwinv3
import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"fmt"
"strings"
"time"
"github.com/fsnotify/fsnotify"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/pkg/util"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
)
const (
MessageFilePattern = "^msg_([0-9]?[0-9])?\\.db$"
ContactFilePattern = "^wccontact_new2\\.db$"
ChatRoomFilePattern = "^group_new\\.db$"
SessionFilePattern = "^session_new\\.db$"
MediaFilePattern = "^hldata\\.db$"
Message = "message"
Contact = "contact"
ChatRoom = "chatroom"
Session = "session"
Media = "media"
)
type DataSource struct {
path string
messageDbs []*sql.DB
contactDb *sql.DB
chatRoomDb *sql.DB
sessionDb *sql.DB
mediaDb *sql.DB
var Groups = []dbm.Group{
{
Name: Message,
Pattern: `^msg_([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
{
Name: Contact,
Pattern: `^wccontact_new2\.db$`,
BlackList: []string{},
},
{
Name: ChatRoom,
Pattern: `group_new\.db$`,
BlackList: []string{},
},
{
Name: Session,
Pattern: `^session_new\.db$`,
BlackList: []string{},
},
{
Name: Media,
Pattern: `^hldata\.db$`,
BlackList: []string{},
},
}
talkerDBMap map[string]*sql.DB
type DataSource struct {
path string
dbm *dbm.DBManager
talkerDBMap map[string]string
user2DisplayName map[string]string
}
func New(path string) (*DataSource, error) {
ds := &DataSource{
path: path,
messageDbs: make([]*sql.DB, 0),
talkerDBMap: make(map[string]*sql.DB),
dbm: dbm.NewDBManager(path),
talkerDBMap: make(map[string]string),
user2DisplayName: make(map[string]string),
}
if err := ds.initMessageDbs(path); err != nil {
for _, g := range Groups {
ds.dbm.AddGroup(g)
}
if err := ds.dbm.Start(); err != nil {
return nil, err
}
if err := ds.initMessageDbs(); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initContactDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initChatRoomDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initSessionDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initMediaDb(path); err != nil {
if err := ds.initChatRoomDb(); err != nil {
return nil, errors.DBInitFailed(err)
}
ds.dbm.AddCallback(Message, func(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := ds.initMessageDbs(); err != nil {
log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name)
}
return nil
})
ds.dbm.AddCallback(ChatRoom, func(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := ds.initChatRoomDb(); err != nil {
log.Err(err).Msgf("Failed to reinitialize chatroom DB: %s", event.Name)
}
return nil
})
return ds, nil
}
func (ds *DataSource) initMessageDbs(path string) error {
func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error {
return ds.dbm.AddCallback(name, callback)
}
files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true)
func (ds *DataSource) initMessageDbs() error {
dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil {
return errors.DBFileNotFound(path, MessageFilePattern, err)
return err
}
if len(files) == 0 {
return errors.DBFileNotFound(path, MessageFilePattern, nil)
}
// 处理每个数据库文件
for _, filePath := range files {
// 连接数据库
db, err := sql.Open("sqlite3", filePath)
talkerDBMap := make(map[string]string)
for _, filePath := range dbPaths {
db, err := ds.dbm.OpenDB(filePath)
if err != nil {
log.Err(err).Msgf("连接数据库 %s 失败", filePath)
log.Err(err).Msgf("获取数据库 %s 失败", filePath)
continue
}
ds.messageDbs = append(ds.messageDbs, db)
// 获取所有表名
rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Chat_%'")
@@ -104,96 +144,42 @@ func (ds *DataSource) initMessageDbs(path string) error {
if talkerMd5 == "" {
continue
}
ds.talkerDBMap[talkerMd5] = db
talkerDBMap[talkerMd5] = filePath
}
rows.Close()
}
ds.talkerDBMap = talkerDBMap
return nil
}
func (ds *DataSource) initContactDb(path string) error {
files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true)
func (ds *DataSource) initChatRoomDb() error {
db, err := ds.dbm.GetDB(ChatRoom)
if err != nil {
return errors.DBFileNotFound(path, ContactFilePattern, err)
return err
}
if len(files) == 0 {
return errors.DBFileNotFound(path, ContactFilePattern, nil)
}
ds.contactDb, err = sql.Open("sqlite3", files[0])
rows, err := db.Query("SELECT m_nsUsrName, IFNULL(nickname,\"\") FROM GroupMember")
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initChatRoomDb(path string) error {
files, err := util.FindFilesWithPatterns(path, ChatRoomFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, ChatRoomFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, ChatRoomFilePattern, nil)
}
ds.chatRoomDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
rows, err := ds.chatRoomDb.Query("SELECT m_nsUsrName, IFNULL(nickname,\"\") FROM GroupMember")
if err != nil {
log.Err(err).Msgf("数据库 %s 获取群聊成员失败", files[0])
log.Err(err).Msg("获取群聊成员失败")
return nil
}
user2DisplayName := make(map[string]string)
for rows.Next() {
var user string
var nickName string
if err := rows.Scan(&user, &nickName); err != nil {
log.Err(err).Msgf("数据库 %s 扫描表名失败", files[0])
log.Err(err).Msg("扫描表名失败")
continue
}
ds.user2DisplayName[user] = nickName
user2DisplayName[user] = nickName
}
rows.Close()
ds.user2DisplayName = user2DisplayName
return nil
}
func (ds *DataSource) initSessionDb(path string) error {
files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, SessionFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, SessionFilePattern, nil)
}
ds.sessionDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initMediaDb(path string) error {
files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, MediaFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, MediaFilePattern, nil)
}
ds.mediaDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
// GetMessages 实现获取消息的方法
func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
// 在 darwinv3 中,每个联系人/群聊的消息存储在单独的表中,表名为 Chat_md5(talker)
@@ -204,10 +190,14 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
_talkerMd5Bytes := md5.Sum([]byte(talker))
talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:])
db, ok := ds.talkerDBMap[talkerMd5]
dbPath, ok := ds.talkerDBMap[talkerMd5]
if !ok {
return nil, errors.TalkerNotFound(talker)
}
db, err := ds.dbm.OpenDB(dbPath)
if err != nil {
return nil, err
}
tableName := fmt.Sprintf("Chat_%s", talkerMd5)
// 构建查询条件
@@ -297,7 +287,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -351,7 +345,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
}
// 执行查询
rows, err := ds.chatRoomDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(ChatRoom)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -380,7 +378,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
contacts, err := ds.GetContacts(ctx, key, 1, 0)
if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") {
// 再次尝试通过用户名查找群聊
rows, err := ds.chatRoomDb.QueryContext(ctx,
rows, err := db.QueryContext(ctx,
`SELECT IFNULL(m_nsUsrName,""), IFNULL(nickname,""), IFNULL(m_nsRemark,""), IFNULL(m_nsChatRoomMemList,""), IFNULL(m_nsChatRoomAdminList,"")
FROM GroupContact
WHERE m_nsUsrName = ?`,
@@ -448,7 +446,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.sessionDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Session)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -506,7 +508,11 @@ WHERE
r.mediaMd5 = ?`
args := []interface{}{key}
// 执行查询
rows, err := ds.mediaDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Media)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -541,46 +547,5 @@ WHERE
// Close 实现关闭数据库连接的方法
func (ds *DataSource) Close() error {
var errs []error
// 关闭消息数据库连接
for _, db := range ds.messageDbs {
if err := db.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭联系人数据库连接
if ds.contactDb != nil {
if err := ds.contactDb.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭群聊数据库连接
if ds.chatRoomDb != nil {
if err := ds.chatRoomDb.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭会话数据库连接
if ds.sessionDb != nil {
if err := ds.sessionDb.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭媒体数据库连接
if ds.mediaDb != nil {
if err := ds.mediaDb.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.DBCloseFailed(errs[0])
}
return nil
return ds.dbm.Close()
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"time"
"github.com/fsnotify/fsnotify"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/darwinv3"
@@ -28,6 +30,9 @@ type DataSource interface {
// 媒体
GetMedia(ctx context.Context, _type string, key string) (*model.Media, error)
// 设置回调函数
SetCallback(name string, callback func(event fsnotify.Event) error) error
Close() error
}

View File

@@ -0,0 +1,170 @@
package dbm
import (
"database/sql"
"runtime"
"sync"
"time"
"github.com/fsnotify/fsnotify"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/pkg/filecopy"
"github.com/sjzar/chatlog/pkg/filemonitor"
)
type DBManager struct {
path string
fm *filemonitor.FileMonitor
fgs map[string]*filemonitor.FileGroup
dbs map[string]*sql.DB
dbPaths map[string][]string
mutex sync.RWMutex
}
func NewDBManager(path string) *DBManager {
return &DBManager{
path: path,
fm: filemonitor.NewFileMonitor(),
fgs: make(map[string]*filemonitor.FileGroup),
dbs: make(map[string]*sql.DB),
dbPaths: make(map[string][]string),
}
}
func (d *DBManager) AddGroup(g Group) error {
fg, err := filemonitor.NewFileGroup(g.Name, d.path, g.Pattern, g.BlackList)
if err != nil {
return err
}
fg.AddCallback(d.Callback)
d.fm.AddGroup(fg)
d.mutex.Lock()
d.fgs[g.Name] = fg
d.mutex.Unlock()
return nil
}
func (d *DBManager) AddCallback(name string, callback func(event fsnotify.Event) error) error {
d.mutex.RLock()
fg, ok := d.fgs[name]
d.mutex.RUnlock()
if !ok {
return errors.FileGroupNotFound(name)
}
fg.AddCallback(callback)
return nil
}
func (d *DBManager) GetDB(name string) (*sql.DB, error) {
dbPaths, err := d.GetDBPath(name)
if err != nil {
return nil, err
}
return d.OpenDB(dbPaths[0])
}
func (d *DBManager) GetDBs(name string) ([]*sql.DB, error) {
dbPaths, err := d.GetDBPath(name)
if err != nil {
return nil, err
}
dbs := make([]*sql.DB, 0)
for _, file := range dbPaths {
db, err := d.OpenDB(file)
if err != nil {
return nil, err
}
dbs = append(dbs, db)
}
return dbs, nil
}
func (d *DBManager) GetDBPath(name string) ([]string, error) {
d.mutex.RLock()
dbPaths, ok := d.dbPaths[name]
d.mutex.RUnlock()
if !ok {
d.mutex.RLock()
fg, ok := d.fgs[name]
d.mutex.RUnlock()
if !ok {
return nil, errors.FileGroupNotFound(name)
}
list, err := fg.List()
if err != nil {
return nil, errors.DBFileNotFound(d.path, fg.PatternStr, err)
}
if len(list) == 0 {
return nil, errors.DBFileNotFound(d.path, fg.PatternStr, nil)
}
dbPaths = list
d.mutex.Lock()
d.dbPaths[name] = dbPaths
d.mutex.Unlock()
}
return dbPaths, nil
}
func (d *DBManager) OpenDB(path string) (*sql.DB, error) {
d.mutex.RLock()
db, ok := d.dbs[path]
d.mutex.RUnlock()
if ok {
return db, nil
}
var err error
tempPath := path
if runtime.GOOS == "windows" {
tempPath, err = filecopy.GetTempCopy(path)
if err != nil {
log.Err(err).Msgf("获取临时拷贝文件 %s 失败", path)
return nil, err
}
}
db, err = sql.Open("sqlite3", tempPath)
if err != nil {
log.Err(err).Msgf("连接数据库 %s 失败", path)
return nil, err
}
d.mutex.Lock()
d.dbs[path] = db
d.mutex.Unlock()
return db, nil
}
func (d *DBManager) Callback(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
d.mutex.Lock()
db, ok := d.dbs[event.Name]
if ok {
delete(d.dbs, event.Name)
go func(db *sql.DB) {
time.Sleep(time.Second * 5)
db.Close()
}(db)
}
d.mutex.Unlock()
return nil
}
func (d *DBManager) Start() error {
return d.fm.Start()
}
func (d *DBManager) Stop() error {
return d.fm.Stop()
}
func (d *DBManager) Close() error {
for _, db := range d.dbs {
db.Close()
}
return d.fm.Stop()
}

View File

@@ -0,0 +1,42 @@
package dbm
import (
"fmt"
"testing"
"time"
)
func TestXxx(t *testing.T) {
path := "/Users/sarv/Documents/chatlog/bigjun_9e7a"
g := Group{
Name: "session",
Pattern: `session\.db$`,
BlackList: []string{},
}
d := NewDBManager(path)
d.AddGroup(g)
d.Start()
i := 0
for {
db, err := d.GetDB("session")
if err != nil {
fmt.Println(err)
break
}
var username string
row := db.QueryRow(`SELECT username FROM SessionTable LIMIT 1`)
if err := row.Scan(&username); err != nil {
fmt.Printf("Error scanning row: %v\n", err)
time.Sleep(100 * time.Millisecond)
continue
}
fmt.Printf("%d: Username: %s\n", i, username)
i++
time.Sleep(1000 * time.Millisecond)
}
}

View File

@@ -0,0 +1,7 @@
package dbm
type Group struct {
Name string
Pattern string
BlackList []string
}

View File

@@ -10,22 +10,51 @@ import (
"strings"
"time"
"github.com/fsnotify/fsnotify"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/pkg/util"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
)
const (
MessageFilePattern = "^message_([0-9]?[0-9])?\\.db$"
ContactFilePattern = "^contact\\.db$"
SessionFilePattern = "^session\\.db$"
MediaFilePattern = "^hardlink\\.db$"
VoiceFilePattern = "^media_([0-9]?[0-9])?\\.db$"
Message = "message"
Contact = "contact"
Session = "session"
Media = "media"
Voice = "voice"
)
var Groups = []dbm.Group{
{
Name: Message,
Pattern: `^message_([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
{
Name: Contact,
Pattern: `^contact\.db$`,
BlackList: []string{},
},
{
Name: Session,
Pattern: `session\.db$`,
BlackList: []string{},
},
{
Name: Media,
Pattern: `^hardlink\.db$`,
BlackList: []string{},
},
{
Name: Voice,
Pattern: `^media_([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
}
// MessageDBInfo 存储消息数据库的信息
type MessageDBInfo struct {
FilePath string
@@ -34,61 +63,65 @@ type MessageDBInfo struct {
}
type DataSource struct {
path string
messageDbs map[string]*sql.DB
contactDb *sql.DB
sessionDb *sql.DB
mediaDb *sql.DB
voiceDb []*sql.DB
path string
dbm *dbm.DBManager
// 消息数据库信息
messageFiles []MessageDBInfo
messageInfos []MessageDBInfo
}
func New(path string) (*DataSource, error) {
ds := &DataSource{
path: path,
messageDbs: make(map[string]*sql.DB),
voiceDb: make([]*sql.DB, 0),
messageFiles: make([]MessageDBInfo, 0),
dbm: dbm.NewDBManager(path),
messageInfos: make([]MessageDBInfo, 0),
}
if err := ds.initMessageDbs(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initContactDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initSessionDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initMediaDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initVoiceDb(path); err != nil {
for _, g := range Groups {
ds.dbm.AddGroup(g)
}
if err := ds.dbm.Start(); err != nil {
return nil, err
}
if err := ds.initMessageDbs(); err != nil {
return nil, errors.DBInitFailed(err)
}
ds.dbm.AddCallback(Message, func(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := ds.initMessageDbs(); err != nil {
log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name)
}
return nil
})
return ds, nil
}
func (ds *DataSource) initMessageDbs(path string) error {
// 查找所有消息数据库文件
files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, MessageFilePattern, err)
func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error {
if name == "chatroom" {
name = Contact
}
return ds.dbm.AddCallback(name, callback)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, MessageFilePattern, nil)
func (ds *DataSource) initMessageDbs() error {
dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil {
return err
}
// 处理每个数据库文件
for _, filePath := range files {
// 连接数据库
db, err := sql.Open("sqlite3", filePath)
infos := make([]MessageDBInfo, 0)
for _, filePath := range dbPaths {
db, err := ds.dbm.OpenDB(filePath)
if err != nil {
log.Err(err).Msgf("连接数据库 %s 失败", filePath)
log.Err(err).Msgf("获取数据库 %s 失败", filePath)
continue
}
@@ -99,108 +132,38 @@ func (ds *DataSource) initMessageDbs(path string) error {
row := db.QueryRow("SELECT timestamp FROM Timestamp LIMIT 1")
if err := row.Scan(&timestamp); err != nil {
log.Err(err).Msgf("获取数据库 %s 的时间戳失败", filePath)
db.Close()
continue
}
startTime = time.Unix(timestamp, 0)
// 保存数据库信息
ds.messageFiles = append(ds.messageFiles, MessageDBInfo{
infos = append(infos, MessageDBInfo{
FilePath: filePath,
StartTime: startTime,
})
// 保存数据库连接
ds.messageDbs[filePath] = db
}
// 按照 StartTime 排序数据库文件
sort.Slice(ds.messageFiles, func(i, j int) bool {
return ds.messageFiles[i].StartTime.Before(ds.messageFiles[j].StartTime)
sort.Slice(infos, func(i, j int) bool {
return infos[i].StartTime.Before(infos[j].StartTime)
})
// 设置结束时间
for i := range ds.messageFiles {
if i == len(ds.messageFiles)-1 {
ds.messageFiles[i].EndTime = time.Now()
for i := range infos {
if i == len(infos)-1 {
infos[i].EndTime = time.Now()
} else {
ds.messageFiles[i].EndTime = ds.messageFiles[i+1].StartTime
infos[i].EndTime = infos[i+1].StartTime
}
}
return nil
}
func (ds *DataSource) initContactDb(path string) error {
files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, ContactFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, ContactFilePattern, nil)
}
ds.contactDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initSessionDb(path string) error {
files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, SessionFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, SessionFilePattern, nil)
}
ds.sessionDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initMediaDb(path string) error {
files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, MediaFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, MediaFilePattern, nil)
}
ds.mediaDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initVoiceDb(path string) error {
files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, VoiceFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, VoiceFilePattern, nil)
}
for _, file := range files {
db, err := sql.Open("sqlite3", file)
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
ds.voiceDb = append(ds.voiceDb, db)
}
ds.messageInfos = infos
return nil
}
// getDBInfosForTimeRange 获取时间范围内的数据库信息
func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo {
var dbs []MessageDBInfo
for _, info := range ds.messageFiles {
for _, info := range ds.messageInfos {
if info.StartTime.Before(endTime) && info.EndTime.After(startTime) {
dbs = append(dbs, info)
}
@@ -234,8 +197,8 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
return nil, err
}
db, ok := ds.messageDbs[dbInfo.FilePath]
if !ok {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath)
continue
}
@@ -275,8 +238,8 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
// getMessagesSingleFile 从单个数据库文件获取消息
func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
db, ok := ds.messageDbs[dbInfo.FilePath]
if !ok {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
return nil, errors.DBConnectFailed(dbInfo.FilePath, nil)
}
@@ -287,7 +250,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
// 检查表是否存在
var exists bool
err := db.QueryRowContext(ctx,
err = db.QueryRowContext(ctx,
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
tableName).Scan(&exists)
@@ -445,7 +408,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -477,13 +444,18 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
var query string
var args []interface{}
// 执行查询
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
if key != "" {
// 按照关键字查询
query = `SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?`
args = []interface{}{key}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -510,7 +482,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
contacts, err := ds.GetContacts(ctx, key, 1, 0)
if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") {
// 再次尝试通过用户名查找群聊
rows, err := ds.contactDb.QueryContext(ctx,
rows, err := db.QueryContext(ctx,
`SELECT username, owner, ext_buffer FROM chat_room WHERE username = ?`,
contacts[0].UserName)
@@ -560,7 +532,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -614,7 +586,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.sessionDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Session)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -678,7 +654,12 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
query += " WHERE f.md5 = ? OR f.file_name LIKE ? || '%'"
args := []interface{}{key, key}
rows, err := ds.mediaDb.QueryContext(ctx, query, args...)
// 执行查询
db, err := ds.dbm.GetDB(Media)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -726,7 +707,12 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e
`
args := []interface{}{key}
for _, db := range ds.voiceDb {
dbs, err := ds.dbm.GetDBs(Voice)
if err != nil {
return nil, errors.DBConnectFailed("", err)
}
for _, db := range dbs {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
@@ -755,38 +741,5 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e
}
func (ds *DataSource) Close() error {
var errs []error
// 关闭消息数据库连接
for _, db := range ds.messageDbs {
if err := db.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭联系人数据库连接
if ds.contactDb != nil {
if err := ds.contactDb.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭会话数据库连接
if ds.sessionDb != nil {
if err := ds.sessionDb.Close(); err != nil {
errs = append(errs, err)
}
}
if ds.mediaDb != nil {
if err := ds.mediaDb.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.DBCloseFailed(errs[0])
}
return nil
return ds.dbm.Close()
}

View File

@@ -9,23 +9,57 @@ import (
"strings"
"time"
"github.com/fsnotify/fsnotify"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/pkg/util"
"github.com/sjzar/chatlog/internal/wechatdb/datasource/dbm"
)
const (
MessageFilePattern = "^MSG([0-9]?[0-9])?\\.db$"
ContactFilePattern = "^MicroMsg.db$"
ImageFilePattern = "^HardLinkImage\\.db$"
VideoFilePattern = "^HardLinkVideo\\.db$"
FileFilePattern = "^HardLinkFile\\.db$"
VoiceFilePattern = "^MediaMSG([0-9])?\\.db$"
Message = "message"
Contact = "contact"
Image = "image"
Video = "video"
File = "file"
Voice = "voice"
)
var Groups = []dbm.Group{
{
Name: Message,
Pattern: `^MSG([0-9]?[0-9])?\.db$`,
BlackList: []string{},
},
{
Name: Contact,
Pattern: `^MicroMsg.db$`,
BlackList: []string{},
},
{
Name: Image,
Pattern: `^HardLinkImage\.db$`,
BlackList: []string{},
},
{
Name: Video,
Pattern: `^HardLinkVideo\.db$`,
BlackList: []string{},
},
{
Name: File,
Pattern: `^HardLinkFile\.db$`,
BlackList: []string{},
},
{
Name: Voice,
Pattern: `^MediaMSG([0-9])?\.db$`,
BlackList: []string{},
},
}
// MessageDBInfo 保存消息数据库的信息
type MessageDBInfo struct {
FilePath string
@@ -36,67 +70,67 @@ type MessageDBInfo struct {
// DataSource 实现了 DataSource 接口
type DataSource struct {
// 消息数据库
messageFiles []MessageDBInfo
messageDbs map[string]*sql.DB
path string
dbm *dbm.DBManager
// 联系人数据库
contactDbFile string
contactDb *sql.DB
imageDb *sql.DB
videoDb *sql.DB
fileDb *sql.DB
voiceDb []*sql.DB
// 消息数据库信息
messageInfos []MessageDBInfo
}
// New 创建一个新的 WindowsV3DataSource
func New(path string) (*DataSource, error) {
ds := &DataSource{
messageFiles: make([]MessageDBInfo, 0),
messageDbs: make(map[string]*sql.DB),
voiceDb: make([]*sql.DB, 0),
path: path,
dbm: dbm.NewDBManager(path),
messageInfos: make([]MessageDBInfo, 0),
}
// 初始化消息数据库
if err := ds.initMessageDbs(path); err != nil {
for _, g := range Groups {
ds.dbm.AddGroup(g)
}
if err := ds.dbm.Start(); err != nil {
return nil, err
}
if err := ds.initMessageDbs(); err != nil {
return nil, errors.DBInitFailed(err)
}
// 初始化联系人数据库
if err := ds.initContactDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initMediaDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
if err := ds.initVoiceDb(path); err != nil {
return nil, errors.DBInitFailed(err)
}
ds.dbm.AddCallback(Message, func(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := ds.initMessageDbs(); err != nil {
log.Err(err).Msgf("Failed to reinitialize message DBs: %s", event.Name)
}
return nil
})
return ds, nil
}
// initMessageDbs 初始化消息数据库
func (ds *DataSource) initMessageDbs(path string) error {
// 查找所有消息数据库文件
files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, MessageFilePattern, err)
func (ds *DataSource) SetCallback(name string, callback func(event fsnotify.Event) error) error {
if name == "chatroom" {
name = Contact
}
return ds.dbm.AddCallback(name, callback)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, MessageFilePattern, nil)
// initMessageDbs 初始化消息数据库
func (ds *DataSource) initMessageDbs() error {
// 获取所有消息数据库文件路径
dbPaths, err := ds.dbm.GetDBPath(Message)
if err != nil {
return err
}
// 处理每个数据库文件
for _, filePath := range files {
// 连接数据库
db, err := sql.Open("sqlite3", filePath)
infos := make([]MessageDBInfo, 0)
for _, filePath := range dbPaths {
db, err := ds.dbm.OpenDB(filePath)
if err != nil {
log.Err(err).Msgf("连接数据库 %s 失败", filePath)
log.Err(err).Msgf("获取数据库 %s 失败", filePath)
continue
}
@@ -106,7 +140,6 @@ func (ds *DataSource) initMessageDbs(path string) error {
rows, err := db.Query("SELECT tableIndex, tableVersion, tableDesc FROM DBInfo")
if err != nil {
log.Err(err).Msgf("查询数据库 %s 的 DBInfo 表失败", filePath)
db.Close()
continue
}
@@ -133,7 +166,6 @@ func (ds *DataSource) initMessageDbs(path string) error {
rows, err = db.Query("SELECT UsrName FROM Name2ID")
if err != nil {
log.Err(err).Msgf("查询数据库 %s 的 Name2ID 表失败", filePath)
db.Close()
continue
}
@@ -150,123 +182,34 @@ func (ds *DataSource) initMessageDbs(path string) error {
rows.Close()
// 保存数据库信息
ds.messageFiles = append(ds.messageFiles, MessageDBInfo{
infos = append(infos, MessageDBInfo{
FilePath: filePath,
StartTime: startTime,
TalkerMap: talkerMap,
})
// 保存数据库连接
ds.messageDbs[filePath] = db
}
// 按照 StartTime 排序数据库文件
sort.Slice(ds.messageFiles, func(i, j int) bool {
return ds.messageFiles[i].StartTime.Before(ds.messageFiles[j].StartTime)
sort.Slice(infos, func(i, j int) bool {
return infos[i].StartTime.Before(infos[j].StartTime)
})
// 设置结束时间
for i := range ds.messageFiles {
if i == len(ds.messageFiles)-1 {
ds.messageFiles[i].EndTime = time.Now()
for i := range infos {
if i == len(infos)-1 {
infos[i].EndTime = time.Now()
} else {
ds.messageFiles[i].EndTime = ds.messageFiles[i+1].StartTime
infos[i].EndTime = infos[i+1].StartTime
}
}
return nil
}
// initContactDb 初始化联系人数据库
func (ds *DataSource) initContactDb(path string) error {
files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, ContactFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, ContactFilePattern, nil)
}
ds.contactDbFile = files[0]
ds.contactDb, err = sql.Open("sqlite3", ds.contactDbFile)
if err != nil {
return errors.DBConnectFailed(ds.contactDbFile, err)
}
return nil
}
// initContactDb 初始化联系人数据库
func (ds *DataSource) initMediaDb(path string) error {
files, err := util.FindFilesWithPatterns(path, ImageFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, ImageFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, ImageFilePattern, nil)
}
ds.imageDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
files, err = util.FindFilesWithPatterns(path, VideoFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, VideoFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, VideoFilePattern, nil)
}
ds.videoDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
files, err = util.FindFilesWithPatterns(path, FileFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, FileFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, FileFilePattern, nil)
}
ds.fileDb, err = sql.Open("sqlite3", files[0])
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
return nil
}
func (ds *DataSource) initVoiceDb(path string) error {
files, err := util.FindFilesWithPatterns(path, VoiceFilePattern, true)
if err != nil {
return errors.DBFileNotFound(path, VoiceFilePattern, err)
}
if len(files) == 0 {
return errors.DBFileNotFound(path, VoiceFilePattern, nil)
}
for _, file := range files {
db, err := sql.Open("sqlite3", file)
if err != nil {
return errors.DBConnectFailed(files[0], err)
}
ds.voiceDb = append(ds.voiceDb, db)
}
ds.messageInfos = infos
return nil
}
// getDBInfosForTimeRange 获取时间范围内的数据库信息
func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []MessageDBInfo {
var dbs []MessageDBInfo
for _, info := range ds.messageFiles {
for _, info := range ds.messageInfos {
if info.StartTime.Before(endTime) && info.EndTime.After(startTime) {
dbs = append(dbs, info)
}
@@ -296,70 +239,19 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
return nil, err
}
db, ok := ds.messageDbs[dbInfo.FilePath]
if !ok {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath)
continue
}
// 构建查询条件
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
if len(talker) > 0 {
talkerID, ok := dbInfo.TalkerMap[talker]
if ok {
conditions = append(conditions, "TalkerId = ?")
args = append(args, talkerID)
} else {
conditions = append(conditions, "StrTalker = ?")
args = append(args, talker)
}
}
query := fmt.Sprintf(`
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
ORDER BY Sequence ASC
`, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker)
if err != nil {
log.Err(err).Msgf("查询数据库 %s 失败", dbInfo.FilePath)
log.Err(err).Msgf("数据库 %s 获取消息失败", dbInfo.FilePath)
continue
}
// 处理查询结果
for rows.Next() {
var msg model.MessageV3
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
&msg.IsSender,
&msg.Type,
&msg.SubType,
&msg.StrContent,
&compressContent,
&bytesExtra,
)
if err != nil {
log.Err(err).Msg("扫描消息行失败")
continue
}
msg.CompressContent = compressContent
msg.BytesExtra = bytesExtra
totalMessages = append(totalMessages, msg.Wrap())
}
rows.Close()
totalMessages = append(totalMessages, messages...)
if limit+offset > 0 && len(totalMessages) >= limit+offset {
break
@@ -388,6 +280,11 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T
// getMessagesSingleFile 从单个数据库文件获取消息
func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) {
db, err := ds.dbm.OpenDB(dbInfo.FilePath)
if err != nil {
return nil, errors.DBConnectFailed(dbInfo.FilePath, nil)
}
// 构建查询条件
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
@@ -419,7 +316,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
}
// 执行查询
rows, err := ds.messageDbs[dbInfo.FilePath].QueryContext(ctx, query, args...)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -453,6 +350,69 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD
return totalMessages, nil
}
// getMessagesFromDB 从数据库获取消息
func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string) ([]*model.Message, error) {
// 构建查询条件
conditions := []string{"Sequence >= ? AND Sequence <= ?"}
args := []interface{}{startTime.Unix() * 1000, endTime.Unix() * 1000}
if len(talker) > 0 {
talkerID, ok := dbInfo.TalkerMap[talker]
if ok {
conditions = append(conditions, "TalkerId = ?")
args = append(args, talkerID)
} else {
conditions = append(conditions, "StrTalker = ?")
args = append(args, talker)
}
}
query := fmt.Sprintf(`
SELECT MsgSvrID, Sequence, CreateTime, StrTalker, IsSender,
Type, SubType, StrContent, CompressContent, BytesExtra
FROM MSG
WHERE %s
ORDER BY Sequence ASC
`, strings.Join(conditions, " AND "))
// 执行查询
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
defer rows.Close()
// 处理查询结果
messages := []*model.Message{}
for rows.Next() {
var msg model.MessageV3
var compressContent []byte
var bytesExtra []byte
err := rows.Scan(
&msg.MsgSvrID,
&msg.Sequence,
&msg.CreateTime,
&msg.StrTalker,
&msg.IsSender,
&msg.Type,
&msg.SubType,
&msg.StrContent,
&compressContent,
&bytesExtra,
)
if err != nil {
return nil, errors.ScanRowFailed(err)
}
msg.CompressContent = compressContent
msg.BytesExtra = bytesExtra
messages = append(messages, msg.Wrap())
}
return messages, nil
}
// GetContacts 实现获取联系人信息的方法
func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset int) ([]*model.Contact, error) {
var query string
@@ -478,7 +438,11 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -516,7 +480,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
args = []interface{}{key}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -543,7 +511,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
contacts, err := ds.GetContacts(ctx, key, 1, 0)
if err == nil && len(contacts) > 0 && strings.HasSuffix(contacts[0].UserName, "@chatroom") {
// 再次尝试通过用户名查找群聊
rows, err := ds.contactDb.QueryContext(ctx,
rows, err := db.QueryContext(ctx,
`SELECT ChatRoomName, Reserved2, RoomData FROM ChatRoom WHERE ChatRoomName = ?`,
contacts[0].UserName)
@@ -593,7 +561,11 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -647,7 +619,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset
}
// 执行查询
rows, err := ds.contactDb.QueryContext(ctx, query, args...)
db, err := ds.dbm.GetDB(Contact)
if err != nil {
return nil, err
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
}
@@ -688,25 +664,29 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*
return nil, errors.DecodeKeyFailed(err)
}
var db *sql.DB
var dbType string
var table1, table2 string
switch _type {
case "image":
db = ds.imageDb
dbType = Image
table1 = "HardLinkImageAttribute"
table2 = "HardLinkImageID"
case "video":
db = ds.videoDb
dbType = Video
table1 = "HardLinkVideoAttribute"
table2 = "HardLinkVideoID"
case "file":
db = ds.fileDb
dbType = File
table1 = "HardLinkFileAttribute"
table2 = "HardLinkFileID"
default:
return nil, errors.MediaTypeUnsupported(_type)
}
db, err := ds.dbm.GetDB(dbType)
if err != nil {
return nil, err
}
query := fmt.Sprintf(`
@@ -768,7 +748,12 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e
`
args := []interface{}{key}
for _, db := range ds.voiceDb {
dbs, err := ds.dbm.GetDBs(Voice)
if err != nil {
return nil, errors.DBConnectFailed("", err)
}
for _, db := range dbs {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.QueryFailed(query, err)
@@ -798,41 +783,5 @@ func (ds *DataSource) GetVoice(ctx context.Context, key string) (*model.Media, e
// Close 实现 DataSource 接口的 Close 方法
func (ds *DataSource) Close() error {
var errs []error
// 关闭消息数据库连接
for _, db := range ds.messageDbs {
if err := db.Close(); err != nil {
errs = append(errs, err)
}
}
// 关闭联系人数据库连接
if ds.contactDb != nil {
if err := ds.contactDb.Close(); err != nil {
errs = append(errs, err)
}
}
if ds.imageDb != nil {
if err := ds.imageDb.Close(); err != nil {
errs = append(errs, err)
}
}
if ds.videoDb != nil {
if err := ds.videoDb.Close(); err != nil {
errs = append(errs, err)
}
}
if ds.fileDb != nil {
if err := ds.fileDb.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.DBCloseFailed(errs[0])
}
return nil
return ds.dbm.Close()
}

View File

@@ -3,6 +3,9 @@ package repository
import (
"context"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog/log"
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/model"
"github.com/sjzar/chatlog/internal/wechatdb/datasource"
@@ -61,6 +64,9 @@ func New(ds datasource.DataSource) (*Repository, error) {
return nil, errors.InitCacheFailed(err)
}
ds.SetCallback("contact", r.contactCallback)
ds.SetCallback("chatroom", r.chatroomCallback)
return r, nil
}
@@ -79,6 +85,26 @@ func (r *Repository) initCache(ctx context.Context) error {
return nil
}
func (r *Repository) contactCallback(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := r.initContactCache(context.Background()); err != nil {
log.Err(err).Msgf("Failed to reinitialize contact cache: %s", event.Name)
}
return nil
}
func (r *Repository) chatroomCallback(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
return nil
}
if err := r.initChatRoomCache(context.Background()); err != nil {
log.Err(err).Msgf("Failed to reinitialize contact cache: %s", event.Name)
}
return nil
}
// Close 实现 Repository 接口的 Close 方法
func (r *Repository) Close() error {
return r.ds.Close()